1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/memory_tracking.hpp"
19#include "common/nstl.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22
23#include "cpu/platform.hpp"
24#include "cpu/scale_utils.hpp"
25#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
26#include "cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp"
27
28#define GET_OFF(field) offsetof(jit_conv_call_s, field)
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35using namespace dnnl::impl::memory_tracking::names;
36using namespace dnnl::impl::data_type;
37using namespace dnnl::impl::utils;
38using namespace Xbyak;
39
40jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t(
41 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
42 const memory_desc_t &dst_md)
43 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
44 , jcp(ajcp)
45 , attr_(attr) {
46 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
47 using namespace binary_injector;
48 const auto &rhs_addr_reg = bin_injector_helper_reg_1;
49 const auto &rhs_helper_reg = bin_injector_helper_reg_2;
50 const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
51 static constexpr bool preserve_gpr = false;
52 static constexpr bool preserve_vmm = false;
53 const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
54 static constexpr bool use_exact_tail_scalar_bcast = true;
55
56 const rhs_arg_static_params_t rhs_arg_static_params {31, rhs_addr_reg,
57 rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr, preserve_vmm,
58 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
59 memory_desc_wrapper(dst_md), tail_size, ktail_mask,
60 use_exact_tail_scalar_bcast};
61 const static_params_t static_params {
62 this->param1, rhs_arg_static_params};
63
64 postops_injector_ = utils::make_unique<
65 injector::jit_uni_postops_injector_t<avx512_core>>(
66 this, jcp.post_ops, static_params);
67 }
68}
69
70// Tile register decomposition
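// The eight tmm registers are split into three groups: output accumulators
// starting at C_BASE (indexed by the (osb, ocb) pair), input tiles starting at
// I_BASE (indexed by osb), and weight tiles starting at W_BASE (indexed by
// ocb). The *_BASE constants themselves come from the kernel header (not
// shown here); the helpers below only compute the linear register index
// within each group.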
int jit_avx512_core_amx_1x1_fwd_kernel_t::get_out_tensor(int h, int i) const {
    return C_BASE + h * jcp.nb_oc_blocking + i;
}
74int jit_avx512_core_amx_1x1_fwd_kernel_t::get_inp_tensor(int h) const {
75 return I_BASE + h;
76}
77int jit_avx512_core_amx_1x1_fwd_kernel_t::get_wei_tensor(int i) const {
78 return W_BASE + i;
79}
80
81bool jit_avx512_core_amx_1x1_fwd_kernel_t::is_bf16() const {
82 return jcp.src_dt == data_type::bf16;
83}
84
85// Code generation
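// Runtime counters track the progress of interleaved post-processing:
// row_count_ counts accumulator rows already flushed from the scratchpad,
// buf_count_ selects which half of the double-buffered scratchpad is active,
// and is_store_done_ / is_buffer_empty_ gate interleave_store() below.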
86void jit_avx512_core_amx_1x1_fwd_kernel_t::init_runtime_counters() {
87 row_count_ = 0;
88 buf_count_ = 0;
89 is_store_done_ = false;
90 is_buffer_empty_ = true;
91}
92
93size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_h_shift() const {
94 return (size_t)jcp.ow * jcp.ngroups * jcp.oc_without_padding;
95}
96
97size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_w_shift() const {
98 return (size_t)jcp.ngroups * jcp.oc_without_padding;
99}
100
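// Activations and outputs use a channels-last (nwc/nhwc/ndhwc) layout, so a
// (h, w, block) triple maps to a byte offset of
//   typesize * (h * W * G * C + w * G * C + block * block_size).
// Illustrative example (hypothetical values, not from a real configuration):
// with typesize_in = 1, iw = 56, ngroups = 1, ic_without_padding = 64 and
// ic_block_int_np = 64, inp_offset(1, 2, 0) = 1 * (56 * 64 + 2 * 64) = 3712.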
101size_t jit_avx512_core_amx_1x1_fwd_kernel_t::inp_offset(
102 int h, int w, int icb) const {
103 return (size_t)jcp.typesize_in
104 * (h * jcp.iw * jcp.ngroups * jcp.ic_without_padding
105 + w * jcp.ngroups * jcp.ic_without_padding
106 + icb * jcp.ic_block_int_np);
107}
108
109size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_row_offset(
110 int h, int w, int ocb) const {
111 return (size_t)jcp.typesize_out
112 * (h * jcp.ow * jcp.ngroups * jcp.oc_without_padding
113 + w * jcp.ngroups * jcp.oc_without_padding
114 + ocb * jcp.oc_block);
115}
116
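// The accumulator scratchpad is double buffered: wsp_buffer_size covers two
// halves, and the parity of buf_count_ selects which half reg_postop points
// at for the next round of interleaved post-processing.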
117void jit_avx512_core_amx_1x1_fwd_kernel_t::update_buffer_pointers() {
118 auto buffer_offset = [=](bool shift) { return ((buf_count_ + shift) % 2); };
119 int wsp_shift = jcp.typesize_acc * (jcp.wsp_buffer_size / 2);
120
121 int postop_shift = wsp_shift * buffer_offset(true);
122
123 mov(reg_postop, wsp_ptr);
124 add(reg_postop, postop_shift);
125
126 buf_count_++;
127}
128
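// Interleaved store: after each tile multiply, up to jcp.per_one_pstore rows
// of a previously computed accumulator block are read back from the
// scratchpad and pushed through the post-op/store pipeline, so the epilogue
// overlaps with the AMX compute instead of running as one large tail.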
129void jit_avx512_core_amx_1x1_fwd_kernel_t::interleave_store() {
130 int scnd_dim = jcp.nb_os_blocking * jcp.tile_width;
131
132 for (int c = 0;
133 c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
134 c++) {
135 int ocb = (row_count_ / scnd_dim);
136 int osb = (row_count_ % scnd_dim) / jcp.tile_width;
137 int row = (row_count_ % scnd_dim) % jcp.tile_width;
138
139 const Zmm zmm_r = zmm_out(row);
140
141 int oh = ((osb * jcp.tile_width + row) / jcp.ow);
142 int ow = ((osb * jcp.tile_width + row) % jcp.ow);
143
144 {
145 // preserve registers used by binary post_ops injector
146 const injector_utils::conditional_register_preserve_guard_t
147 cond_register_guard(jcp.with_binary, this,
148 {bin_injector_helper_reg_1,
149 bin_injector_helper_reg_2,
150 bin_injector_helper_reg_3});
151 const int wsp_row_offset = jcp.typesize_acc
152 * (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
153 + ocb * jcp.max_width * jcp.oc_block
154 + row * jcp.oc_block);
155
156 vmovups(zmm_r, ptr[reg_postop + wsp_row_offset]);
157 store_output_vector(zmm_r, ocb, oh, ow);
158 row_count_++;
159 }
160
161 int exp_row_count
162 = jcp.tile_width * jcp.nb_oc_blocking * jcp.nb_os_blocking;
163 if (row_count_ == exp_row_count) {
164 int oh = ((jcp.nb_os_blocking * jcp.tile_width) / jcp.ow);
165 int ow = ((jcp.nb_os_blocking * jcp.tile_width) % jcp.ow);
166 size_t out_offset = jcp.typesize_out
167 * (oh * out_h_shift() + ow * out_w_shift());
168 add(out_ptr, out_offset);
169 row_count_ = 0;
170 is_store_done_ = true;
171 }
172 }
173}
174
175Ymm jit_avx512_core_amx_1x1_fwd_kernel_t::ymm_mask(
176 const Ymm ymm_in, bool mask_flag, bool store) {
177 return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
178 : ymm_in;
179}
180
181Zmm jit_avx512_core_amx_1x1_fwd_kernel_t::zmm_mask(
182 const Zmm zmm_in, bool mask_flag, bool store) {
183 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
184 : zmm_in;
185}
186
void jit_avx512_core_amx_1x1_fwd_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm zmm_in, const Operand &op, bool mask_flag) {
189 using namespace dnnl::impl::data_type;
190 const Zmm zmm = zmm_mask(zmm_in, mask_flag);
191 switch (type_in) {
192 case bf16:
193 vpmovzxwd(zmm, op);
194 vpslld(zmm_in, zmm_in, 16);
195 break;
196 case f32:
197 case s32: vmovups(zmm, op); break;
198 case s8: vpmovsxbd(zmm, op); break;
199 case u8: vpmovzxbd(zmm, op); break;
200 default: assert(!"unsupported data type");
201 }
202 if (utils::one_of(type_in, s32, s8, u8)) vcvtdq2ps(zmm_in, zmm_in);
203}
204
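// Sum post-op, registered as a lambda on the post-ops injector: the previous
// destination value is loaded and converted to f32, optionally shifted by the
// sum zero-point, and then either added directly (scale == 1) or fused in
// with vfmadd231ps against the broadcast sum scale.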
205void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_sum(const Zmm zmm_out,
206 const float *p_sum_scale, const int32_t *p_sum_zp,
207 const Xbyak::Address &addr, const bool mask_flag) {
208 if (p_sum_scale) {
209 const auto p_sum_scale_val = *p_sum_scale;
210 const auto p_sum_zp_val = *p_sum_zp;
211 const auto sum_injector = [&, zmm_out, p_sum_scale_val, p_sum_zp_val,
212 mask_flag]() {
213 cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
214 if (p_sum_zp_val != 0) {
215 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
216 vsubps(zmm_prev_dst, zmm_sum_zp);
217 }
218 if (p_sum_scale_val == 1.f)
219 vaddps(zmm_out, zmm_prev_dst);
220 else
221 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
222 };
223 postops_injector_->set_lambda_injector(
224 primitive_kind::sum, sum_injector);
225 }
226}
227
228void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_postops(const Zmm zmm_out,
229 const float *p_sum_scale, const int32_t *p_sum_zp,
230 const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
231 if (jcp.with_eltwise || jcp.with_binary
232 || (jcp.with_sum && p_sum_scale != nullptr)) {
233 apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);
234
235 const auto vmm_idx = zmm_out.getIdx();
236 if (jcp.with_binary) {
237 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
238 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, out_ptr);
239 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
240 if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
241
242 postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
243 } else {
244 postops_injector_->compute_vector(vmm_idx);
245 }
246 }
247}
248
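// "Fast" post-op chains are the ones the vectorized store routines handle
// inline (sum and/or ReLU, in that order); anything else falls back to the
// per-row store_output_vector() path that goes through the post-ops injector.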
249bool jit_avx512_core_amx_1x1_fwd_kernel_t::is_fast_postops(
250 const jit_conv_conf_t &jcp) {
251 const auto &p = jcp.post_ops;
252 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
253 auto is_sum = [&](int idx) {
254 const bool require_scale_one = jcp.src_dt == data_type::bf16;
255 return p.entry_[idx].is_sum(require_scale_one);
256 };
257 switch (p.len()) {
258 case 0: return true;
259 case 1: return is_relu(0) || is_sum(0);
260 case 2: return is_sum(0) && is_relu(1);
261 default: return false;
262 }
263 return false;
264}
265
266inline void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_ymm_bf16(
267 const int idx, const Xbyak::Address &addr, const bool mask_flag) {
268 Ymm ymm_out = Ymm(idx);
269 vcvtneps2bf16(ymm_out, Zmm(idx));
270 vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
271}
272
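// Int8 epilogue over a full block of tile rows, applied in this order:
// src zero-point compensation (still in s32), conversion to f32, per-oc
// scales, bias, sum, ReLU, dst scale, dst zero-point, saturation and
// down-conversion to the destination data type.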
273void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors_int8(
274 int ocb, int osb) {
275 const bool mask_flag
276 = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
277 const auto &p = attr_.post_ops_;
278 const int sum_idx = p.find(primitive_kind::sum);
279 const float *p_sum_scale = nullptr;
280 const int32_t *p_sum_zp = nullptr;
281 if (sum_idx != -1) {
282 const auto &p_entry = p.entry_[sum_idx];
283 p_sum_scale = &p_entry.sum.scale;
284 p_sum_zp = &p_entry.sum.zero_point;
285 }
286 if (p_sum_scale) {
287 if (*p_sum_scale != 1.f)
288 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
289 if (*p_sum_zp != 0)
290 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
291 }
292
293 if (jcp.src_zero_point) {
294 const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
295 const Zmm zmm_zp_m = zmm_mask(zmm_zp, mask_flag);
296 vpmulld(zmm_zp_m, zmm_src_zp,
297 EVEX_compress_addr(reg_zp_compensation, zp_offset));
298 for (int j = 0; j < jcp.tile_width; j++) {
299 const Zmm zmm_r = zmm_out(j);
300 vpaddd(zmm_r, zmm_r, zmm_zp_m);
301 }
302 }
303
304 for (int j = 0; j < jcp.tile_width; j++) {
305 const Zmm zmm_r = zmm_out(j);
306 vcvtdq2ps(zmm_r, zmm_r);
307 }
308
309 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
310 for (int j = 0; j < jcp.tile_width; j++) {
311 const int scale_offset
312 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
313 const Zmm zmm_r = zmm_out(j);
314 const Zmm zmm_r_msk = zmm_mask(zmm_r, mask_flag);
315 vmulps(zmm_r_msk, zmm_r,
316 EVEX_compress_addr(reg_ptr_scales, scale_offset));
317 }
318
319 if (jcp.with_bias) {
320 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
321 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
322 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
323 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
324 for (int j = 0; j < jcp.tile_width; j++) {
325 const Zmm zmm_r = zmm_out(j);
326 vaddps(zmm_r, zmm_r, zmm_bias);
327 }
328 }
329
330 if (p_sum_zp && *p_sum_zp != 0)
331 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
332 if (jcp.with_sum && p_sum_scale != nullptr) {
333 const auto p_sum_scale_val = *p_sum_scale;
334 const auto p_sum_zp_val = *p_sum_zp;
335 for (int j = 0; j < jcp.tile_width; j++) {
336 int h = ((osb * jcp.tile_width + j) / jcp.ow);
337 int w = ((osb * jcp.tile_width + j) % jcp.ow);
338
339 const auto off = out_row_offset(h, w, ocb);
340 const auto addr = EVEX_compress_addr(out_ptr, off);
341
342 const Zmm zmm_r = zmm_out(j);
343 cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
344 if (p_sum_zp_val != 0) vsubps(zmm_prev_dst, zmm_sum_zp);
345 if (p_sum_scale_val == 1.f)
346 vaddps(zmm_r, zmm_prev_dst);
347 else
348 vfmadd231ps(zmm_r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
349 }
350 }
351 if (jcp.with_eltwise) {
352 vxorps(zmm_zero, zmm_zero, zmm_zero);
353 for (int j = 0; j < jcp.tile_width; j++) {
354 const Zmm zmm_r = zmm_out(j);
355 vmaxps(zmm_r, zmm_r, zmm_zero);
356 }
357 }
358
359 if (jcp.dst_scale) {
360 mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
361 for (int j = 0; j < jcp.tile_width; j++) {
362 const Zmm zmm_r = zmm_out(j);
363 const Zmm zmm_r_msk = zmm_mask(zmm_r, mask_flag);
364 vmulps(zmm_r_msk, zmm_r, EVEX_compress_addr(reg_ptr_dst_scale, 0));
365 }
366 }
367
368 if (jcp.dst_zero_point) {
369 for (int j = 0; j < jcp.tile_width; j++) {
370 const Zmm zmm_r = zmm_out(j);
371 vaddps(zmm_r, zmm_r, zmm_dst_zp);
372 }
373 }
374
375 // Properly saturate the accumulators for integer datatypes
376 if (one_of(jcp.dst_dt, u8, s8, s32)) {
377 init_saturate_f32(
378 zmm_zero, zmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
379 for (int j = 0; j < jcp.tile_width; j++) {
380 const Zmm zmm_r = zmm_out(j);
381 saturate_f32(zmm_r, zmm_zero, zmm_saturation, jcp.dst_dt);
382 vcvtps2dq(zmm_r, zmm_r);
383 }
384 }
385
386 for (int j = 0; j < jcp.tile_width; j++) {
387 const int h = ((osb * jcp.tile_width + j) / jcp.ow);
388 const int w = ((osb * jcp.tile_width + j) % jcp.ow);
389 const auto off = out_row_offset(h, w, ocb);
390 const auto addr = EVEX_compress_addr(out_ptr, off);
391
392 const Zmm zmm_out_store = zmm_mask(zmm_out(j), mask_flag, true);
393 switch (jcp.dst_dt) {
394 case data_type::f32:
395 case data_type::s32: vmovups(addr, zmm_out_store); break;
396 case data_type::bf16:
397 store_output_ymm_bf16(zmm_out_store.getIdx(), addr, mask_flag);
398 break;
399 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
400 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
401 default: assert(!"unknown dst_dt");
402 }
403 }
404}
405
406void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_int8(
407 const Zmm zmm_out, int ocb, int h, int w) {
408
409 const auto off = out_row_offset(h, w, ocb);
410 const auto addr = EVEX_compress_addr(out_ptr, off);
411
412 const bool mask_flag
413 = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
414 const auto &p = attr_.post_ops_;
415 const int sum_idx = p.find(primitive_kind::sum);
416 const float *p_sum_scale = nullptr;
417 const int32_t *p_sum_zp = nullptr;
418 if (sum_idx != -1) {
419 const auto &p_entry = p.entry_[sum_idx];
420 p_sum_scale = &p_entry.sum.scale;
421 p_sum_zp = &p_entry.sum.zero_point;
422 }
423
424 if (p_sum_scale) {
425 if (*p_sum_scale != 1.f)
426 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
427 if (*p_sum_zp != 0)
428 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
429 }
430
431 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
432 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
433
434 int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
435 if (jcp.with_bias) {
436 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
437 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
438 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
439 }
440 if (jcp.src_zero_point) {
441 const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
442 const Zmm zmm_zp_m = zmm_mask(zmm_zp, mask_flag);
443 vpmulld(zmm_zp_m, zmm_src_zp,
444 EVEX_compress_addr(reg_zp_compensation, zp_offset));
445 vpaddd(zmm_out, zmm_out, zmm_zp_m);
446 }
    /* convert the s32 accumulator to f32, then apply scales and bias */
448 vcvtdq2ps(zmm_out, zmm_out);
449
450 const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
451 vmulps(zmm_out_msk, zmm_out,
452 EVEX_compress_addr(reg_ptr_scales, scale_offset));
453
454 if (jcp.with_bias) vaddps(zmm_out_msk, zmm_out, zmm_bias);
455
456 apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);
457
458 if (jcp.dst_scale) {
459 mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
460 vmulps(zmm_out, zmm_out, EVEX_compress_addr(reg_ptr_dst_scale, 0));
461 }
462 if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
463
464 // Properly saturate the accumulators for integer datatypes
465 if (one_of(jcp.dst_dt, u8, s8, s32)) {
466 init_saturate_f32(
467 zmm_zero, zmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
468 saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
469 vcvtps2dq(zmm_out, zmm_out);
470 }
471
472 const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
473 switch (jcp.dst_dt) {
474 case data_type::f32:
475 case data_type::s32: vmovups(addr, zmm_out_store); break;
476 case data_type::bf16:
477 store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
478 break;
479 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
480 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
481 default: assert(!"unknown dst_dt");
482 }
483}
484
485void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors_bf16(
486 int ocb, int osb) {
487 const bool mask_flag
488 = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
489
490 if (jcp.with_bias) {
491 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
492 const int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
493 const auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
494 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
495 for (int j = 0; j < jcp.tile_width; j++) {
496 const Zmm zmm_r = zmm_out(j);
497 vaddps(zmm_r, zmm_r, zmm_bias);
498 }
499 }
500
501 if (jcp.with_sum) {
502 for (int j = 0; j < jcp.tile_width; j++) {
503 int h = ((osb * jcp.tile_width + j) / jcp.ow);
504 int w = ((osb * jcp.tile_width + j) % jcp.ow);
505 const auto off = out_row_offset(h, w, ocb);
506 const auto addr = EVEX_compress_addr(out_ptr, off);
507 const Zmm zmm_r = zmm_out(j);
508 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
509 vaddps(zmm_r, zmm_prev_dst);
510 }
511 }
512 if (jcp.with_eltwise) {
513 vxorps(zmm_zero, zmm_zero, zmm_zero);
514 for (int j = 0; j < jcp.tile_width; j++) {
515 const Zmm zmm_r = zmm_out(j);
516 vmaxps(zmm_r, zmm_r, zmm_zero);
517 }
518 }
519
520 for (int j = 0; j < jcp.tile_width; j++) {
521 const int h = ((osb * jcp.tile_width + j) / jcp.ow);
522 const int w = ((osb * jcp.tile_width + j) % jcp.ow);
523 const auto off = out_row_offset(h, w, ocb);
524 const auto addr = EVEX_compress_addr(out_ptr, off);
525 const Zmm zmm_r = zmm_out(j);
526 if (jcp.dst_dt == data_type::bf16) {
527 store_output_ymm_bf16(zmm_r.getIdx(), addr, mask_flag);
528 } else {
529 vmovups(addr, zmm_mask(zmm_r, mask_flag, true));
530 }
531 }
532}
533
534void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_bf16(
535 const Zmm zmm_out, int ocb, int h, int w) {
536 const auto off = out_row_offset(h, w, ocb);
537 const auto addr = EVEX_compress_addr(out_ptr, off);
538
539 const bool mask_flag
540 = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
541
542 const auto &p = attr_.post_ops_;
543
544 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
545
546 const int sum_idx = p.find(primitive_kind::sum);
547 if (sum_idx != -1) {
548 if (jcp.dst_dt == data_type::bf16) {
549 vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
550 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
551 vaddps(zmm_out, zmm_prev_dst);
552 } else {
553 vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
554 vaddps(zmm_out, zmm_prev_dst);
555 }
556 }
557 if (jcp.with_bias) {
558 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
559 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
560 if (jcp.bia_dt == data_type::bf16) {
561 vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
562 vpslld(zmm_bias, zmm_bias, 16);
563 vaddps(zmm_out, zmm_bias);
564 } else
565 vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
566 }
567
568 static constexpr auto skip_sum_in_injection = nullptr;
569 apply_postops(zmm_out, skip_sum_in_injection, skip_sum_in_injection, addr,
570 off, mask_flag);
571
572 if (jcp.dst_dt == data_type::bf16) {
573 store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
574 } else {
575 vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
576 }
577}
578
579// Store all rows of a tile
580void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors(
581 int ocb, int osb) {
582 if (is_bf16()) {
583 store_output_vectors_bf16(ocb, osb);
584 } else {
585 store_output_vectors_int8(ocb, osb);
586 }
587}
588
589// Store single row
590void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector(
591 const Zmm zmm_out, int ocb, int h, int w) {
592 if (is_bf16()) {
593 store_output_vector_bf16(zmm_out, ocb, h, w);
594 } else {
595 store_output_vector_int8(zmm_out, ocb, h, w);
596 }
597}
598
599void jit_avx512_core_amx_1x1_fwd_kernel_t::prepare_output() {
600 for (int osb = 0; osb < jcp.nb_os_blocking; osb++)
601 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
602 tilezero(Tmm(get_out_tensor(osb, ocb)));
603}
604
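// Spill the output tiles to the accumulator scratchpad with tilestored and,
// when do_store is set, run the epilogue: either the per-row path
// (store_output_vector) or the vectorized fast-post-ops path
// (store_output_vectors). When check_last_sb_ is set, the runtime last_h
// argument selects between the full nb_os_blocking block and a single block.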
605void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output(
606 bool do_store, bool has_tail) {
607
608 auto store_output_subblock = [=](int ocb, int osb) {
609 const int wsp_offset = jcp.typesize_acc
610 * (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
611 + ocb * jcp.max_width * jcp.oc_block);
612 tilestored(ptr[wsp_ptr + stride_seq + wsp_offset],
613 Tmm(get_out_tensor(osb, ocb)));
614
615 // preserve registers used by binary post_ops injector
616 const injector_utils::conditional_register_preserve_guard_t
617 cond_register_guard(jcp.with_binary, this,
618 {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
619 is_buffer_empty_ = false;
        is_store_done_ = do_store;
621 for (int j = 0; j < jcp.tile_width && do_store; j++) {
622 int oh_ = ((osb * jcp.tile_width + j) / jcp.ow);
623 int ow_ = ((osb * jcp.tile_width + j) % jcp.ow);
624
625 auto addr = ptr[wsp_ptr + jcp.typesize_acc * (j * jcp.oc_block)
626 + wsp_offset];
627 const Zmm zmm_r = zmm_out(j);
628 vmovups(zmm_r, addr);
629 if (!jcp.is_fast_postops) store_output_vector(zmm_r, ocb, oh_, ow_);
630 }
631 if (do_store && jcp.is_fast_postops) store_output_vectors(ocb, osb);
632 };
633
634 auto store_output_block = [=](int os_b = 1) {
635 if (jcp.src_zero_point) {
636 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
637 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
638 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
639 }
640 if (jcp.dst_zero_point) {
641 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
642 vcvtdq2ps(zmm_dst_zp,
643 EVEX_compress_addr(reg_dst_zero_point, 0, true));
644 }
645 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
646 for (int osb = 0; osb < os_b; osb++)
647 store_output_subblock(ocb, osb);
648 };
649
650 Label label_oc_store, label_done;
651
652 if (check_last_sb_) {
653 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
654 cmp(reg_last_h, 1);
655 je(label_oc_store, T_NEAR);
656 }
657
658 store_output_block(jcp.nb_os_blocking);
659 jmp(label_done, T_NEAR);
660
661 L(label_oc_store);
662 store_output_block();
663
664 L(label_done);
665 update_buffer_pointers();
666}
667
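// Main reduction over input-channel blocks: for each icb, load the input and
// weight tiles and issue the matching tdpb* instruction, interleaving partial
// stores of previously finished accumulators. An input-channel tail, when
// present on the single-block path, is handled by reconfiguring the tiles to
// a narrower shape for the last block.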
668void jit_avx512_core_amx_1x1_fwd_kernel_t::icb_loop(bool do_store) {
669 enum tiles_cfg_t { cfg_tiles, cfg_tiles_tail };
670 enum restore_tiles_t { write_tiles, read_tiles };
671
672 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
673 if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
674 tdpbf16ps(x1, x2, x3);
675 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
676 tdpbuud(x1, x2, x3);
677 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
678 tdpbusd(x1, x2, x3);
679 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
680 tdpbsud(x1, x2, x3);
681 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
682 tdpbssd(x1, x2, x3);
683 } else {
684 assert(!"unsupported combination");
685 }
686 };
687
688 auto tileloadd_nt = [=](const Tmm &t1, int offset) {
689 int ab_size = jcp.nb_os2_blocking * jcp.nb_os_blocking * jcp.tile_width
690 * (jcp.nb_ic_int * jcp.ic_block_int_np
691 + jcp.nb_oc_blocking * jcp.oc_block);
692 int c_size = (jcp.nb_ic_int * jcp.ic_block_int_np * jcp.nb_oc_blocking
693 * jcp.oc_block);
        // If the src and wei data used in the kernel cannot fit into the L1
        // cache, use a non-temporal load for the weights to help keep src in L1
696 if (static_cast<size_t>(jcp.typesize_in * (ab_size + c_size))
697 >= platform::get_per_core_cache_size(1))
698 tileloaddt1(t1, ptr[wei_ptr + offset + stride_seq]);
699 else
700 tileloadd(t1, ptr[wei_ptr + offset + stride_seq]);
701 };
702
703 auto compute_block = [=](int icb, int os_b) {
704 for (int osb = 0; osb < os_b; osb++) {
705 int ih = ((osb * jcp.tile_width) / jcp.ow) * jcp.stride_h;
706 int iw = ((osb * jcp.tile_width) % jcp.ow) * jcp.stride_w;
707 tileloadd(Tmm(get_inp_tensor(osb)),
708 ptr[inp_ptr + stride_nhwc + inp_offset(ih, iw, icb)]);
709 }
710 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
711 const int wei_offset = jcp.typesize_in
712 * (ocb
713 * utils::rnd_up(jcp.ic_without_padding,
714 jcp.ic_block_int)
715 * jcp.oc_block
716 + icb * jcp.ic_block_int_np * jcp.oc_block);
717 tileloadd_nt(Tmm(get_wei_tensor(ocb)), wei_offset);
718 for (int osb = 0; osb < os_b; osb++) {
719 tdpbxxd(Tmm(get_out_tensor(osb, ocb)), Tmm(get_inp_tensor(osb)),
720 Tmm(get_wei_tensor(ocb)));
721 interleave_store();
722 }
723 }
724 };
725
726 auto reconfig_tiles = [=](tiles_cfg_t cfg) {
727 tilerelease();
728 if (cfg == cfg_tiles) {
729 mov(reg_scratch, ptr[param1 + GET_OFF(tile_cfg)]);
730 } else if (cfg == cfg_tiles_tail) {
731 mov(reg_scratch, ptr[param1 + GET_OFF(tile_cfg_tail)]);
732 }
733 ldtilecfg(ptr[reg_scratch]);
734 };
735
736 auto restore_output_tiles = [=](int os_b, restore_tiles_t restore) {
737 mov(reg_tilebuff, ptr[param1 + GET_OFF(src_prf)]);
738 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
739 for (int osb = 0; osb < os_b; osb++) {
740 const int wsp_offset = jcp.typesize_acc
741 * (osb * jcp.nb_oc_blocking * jcp.max_width
742 * jcp.oc_block
743 + ocb * jcp.max_width * jcp.oc_block);
744 if (restore == write_tiles)
745 tilestored(ptr[reg_tilebuff + stride_seq + wsp_offset],
746 Tmm(get_out_tensor(osb, ocb)));
747 else if (restore == read_tiles)
748 tileloadd(Tmm(get_out_tensor(osb, ocb)),
749 ptr[reg_tilebuff + stride_seq + wsp_offset]);
750 }
751 };
752
753 auto reset_tiles = [=](int os_b, bool tail) {
754 if (jcp.nb_ic_int != 1) {
755 restore_output_tiles(os_b, write_tiles);
756 reconfig_tiles((tail) ? cfg_tiles_tail : cfg_tiles);
757 restore_output_tiles(os_b, read_tiles);
758 }
759 };
760
761 auto compute_icb_loop = [=](int os_b = 1) {
762 int shift = (get_ic_tail() && os_b == 1) ? 1 : 0;
763 int nb_ic_int = jcp.nb_ic_int - shift;
764
765 if (jcp.src_zero_point) {
766 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
767 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
768 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
769 }
770 if (jcp.dst_zero_point) {
771 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
772 vcvtdq2ps(zmm_dst_zp,
773 EVEX_compress_addr(reg_dst_zero_point, 0, true));
774 }
775
776 for (int icb = 0; icb < nb_ic_int; icb++)
777 compute_block(icb, os_b);
778
779 // Tail processing
780 if (get_ic_tail() && os_b == 1) {
781 reset_tiles(os_b, true);
782 compute_block(nb_ic_int, os_b);
783 reset_tiles(os_b, false);
784 }
785 };
786
787 Label label_last_os, label_compute_done, label_tail, label_done;
788
789 int stride_nhwc_ = jcp.typesize_in * jcp.ngroups * jcp.ic_without_padding
790 * jcp.stride_w;
791 mov(stride_nhwc, stride_nhwc_);
792
793 prepare_output();
794 { // Compute
795 if (check_last_sb_) {
796 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
797 cmp(reg_last_h, 1);
798 je(label_last_os, T_NEAR);
799 }
800 compute_icb_loop(jcp.nb_os_blocking);
801
802 jmp(label_compute_done, T_NEAR);
803
804 L(label_last_os);
805 compute_icb_loop();
806 }
807 L(label_compute_done);
808 { // Store
809 if (jcp.tile_tail && check_last_sb_)
810 store_output(do_store, true);
811 else
812 store_output(do_store, false);
813 }
814}
815
816void jit_avx512_core_amx_1x1_fwd_kernel_t::osb_loop(int nb_os) {
817 for (int osi = 0; osi < nb_os; osi++) {
818 bool do_store = IMPLICATION(jcp.per_one_pstore, (osi == nb_os - 1));
819 check_last_sb_ = do_store;
820
821 icb_loop(do_store);
822
823 int oh = (((osi + 1) * jcp.nb_os_blocking * jcp.tile_width) / jcp.ow);
824 int ow = (((osi + 1) * jcp.nb_os_blocking * jcp.tile_width) % jcp.ow);
825 if (do_store) {
826 size_t out_offset = jcp.typesize_out
827 * (oh * out_h_shift() + ow * out_w_shift());
828 add(out_ptr, out_offset);
829 }
830
831 int ih = oh * jcp.stride_h;
832 int iw = ow * jcp.stride_w;
833 add(inp_ptr, inp_offset(ih, iw, 0));
834 }
835}
836
837int jit_avx512_core_amx_1x1_fwd_kernel_t::get_ic_tail() const {
838 return (jcp.ic_without_padding % jcp.ic_block_int_np);
839}
840
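// Kernel entry point: set up the output-channel tail mask, load the argument
// pointers from jit_conv_call_s, initialize the runtime counters and buffer
// pointers, and emit either the blocked osb loop (is_osb != 0) or the
// single-block variant.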
841void jit_avx512_core_amx_1x1_fwd_kernel_t::generate() {
842 preamble();
843
844 last_oc_block_flag_ = (jcp.oc_without_padding != jcp.oc);
845 if (last_oc_block_flag_) {
846 Xbyak::Label mask_is_set;
847
        // Use the full mask 0xffff by default for all output data and post-ops
849 // loads / stores with block index
850 // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
851 // TODO: use masked loads / stores for the last occ only
852 int mask = (1 << jcp.oc_block) - 1;
853 Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
854 mov(regw_tmp, mask);
855 kmovw(ktail_mask, regw_tmp);
856 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
857 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
858 jne(mask_is_set, T_NEAR);
859
860 // Reset the mask
861 mask = (1 << (jcp.oc_without_padding % jcp.oc_block)) - 1;
862 mov(regw_tmp, mask);
863 kmovw(ktail_mask, regw_tmp);
864
865 L(mask_is_set);
866 }
867
868 mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
869 mov(wei_ptr, ptr[param1 + GET_OFF(filt)]);
870 mov(out_ptr, ptr[param1 + GET_OFF(dst)]);
871 mov(wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
872
873 mov(reg_is_osb, ptr[param1 + GET_OFF(is_osb)]);
874
875 constexpr int tile_mem_stride_in_bytes = 64;
876 mov(stride_seq, tile_mem_stride_in_bytes);
877
878 init_runtime_counters();
879 update_buffer_pointers();
880
881 Xbyak::Label label_no_osb, label_done;
882
883 cmp(reg_is_osb, 0);
884 je(label_no_osb, T_NEAR);
885
886 osb_loop(jcp.nb_os2_blocking);
887 jmp(label_done, T_NEAR);
888
889 L(label_no_osb);
890 osb_loop();
891
892 L(label_done);
893 postamble();
894
895 if (jcp.with_eltwise) postops_injector_->prepare_table();
896}
897
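// Build the AMX palette(s) consumed by ldtilecfg. A tiles (inputs) are
// tile_width rows of Ac bytes, B tiles (weights) are Ac / typesize_acc rows
// of the maximum column width, and C tiles (accumulators) are tile_width rows
// of the maximum column width. Illustrative example (hypothetical values):
// for int8 with ic_block_int_np = 64 and tile_width = 16, Ac = 1 * 64 = 64
// bytes and the B tile has 64 / 4 = 16 rows. A second, narrower configuration
// is appended when an input-channel tail requires it.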
898void jit_avx512_core_amx_1x1_fwd_kernel_t::tile_configure(char *tcfg_buff) {
899
900 int tile_max_columns_in_bytes
901 = amx::get_max_column_bytes(amx::get_target_palette());
902 const int max_palette_size_in_bytes = 64;
903
904 auto cfg_tiles = [=](palette_config_t *buff, int Ac) {
905 char *_tc = (char *)buff;
906 for (int i = 0; i < max_palette_size_in_bytes; i++)
907 _tc[i] = 0;
908
909 int Ar = jcp.tile_width;
910 int Br = Ac / jcp.typesize_acc;
911 int Cr = jcp.tile_width;
912
913 int Bc = tile_max_columns_in_bytes;
914 int Cc = tile_max_columns_in_bytes;
915
916 for (int s = 0; s < jcp.nb_os_blocking; s++)
917 tc_configure_tile(buff, get_inp_tensor(s), Ar, Ac);
918 for (int i = 0; i < jcp.nb_oc_blocking; i++)
919 tc_configure_tile(buff, get_wei_tensor(i), Br, Bc);
920
921 for (int s = 0; s < jcp.nb_os_blocking; s++)
922 for (int i = 0; i < jcp.nb_oc_blocking; i++) {
923 tc_configure_tile(buff, get_out_tensor(s, i), Cr, Cc);
924 }
925
926 buff->palette_id = amx::get_target_palette();
927 };
928
929 int Ac = jcp.typesize_in
930 * ((jcp.nb_ic_int == 1 && get_ic_tail()) ? get_ic_tail()
931 : jcp.ic_block_int_np);
932
933 cfg_tiles((palette_config_t *)tcfg_buff, Ac);
934 if (jcp.nb_ic_int > 1 && get_ic_tail()) {
935 int Ac = jcp.typesize_in * get_ic_tail();
936 char *_t = tcfg_buff + max_palette_size_in_bytes;
937 cfg_tiles((palette_config_t *)(_t), Ac);
938 }
939}
940
941status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
942 const convolution_desc_t &cd, memory_desc_t &src_md,
943 memory_desc_t &weights_md, memory_desc_t &dst_md,
944 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
945 using namespace prop_kind;
946
947 const memory_desc_wrapper src_d(&src_md);
948 const memory_desc_wrapper weights_d(&weights_md);
949 const memory_desc_wrapper dst_d(&dst_md);
950 const memory_desc_wrapper bias_d(&bias_md);
951
952 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
953 int ndims = src_d.ndims();
954 bool is_1d = ndims == 3;
955 bool is_3d = ndims == 5;
956
957 const bool is_bf16_convolution
958 = everyone_is(true, src_d.data_type() == data_type::bf16,
959 weights_d.data_type() == data_type::bf16,
960 one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
961 const bool is_int8_convolution = everyone_is(true,
962 (src_d.data_type() == data_type::u8
963 || src_d.data_type() == data_type::s8),
964 weights_d.data_type() == data_type::s8,
965 one_of(dst_d.data_type(), data_type::f32, data_type::s32,
966 data_type::s8, data_type::u8, data_type::bf16));
967
968 bool supported = mayiuse(avx512_core_amx)
969 && (is_bf16_convolution || is_int8_convolution);
970 if (!supported) return status::unimplemented;
971
972 jcp = zero<decltype(jcp)>();
973 jcp.isa = avx512_core_amx;
974 jcp.ndims = ndims;
975 jcp.prop_kind = cd.prop_kind;
976 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
977 jcp.mb = src_d.dims()[0];
978 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
979 jcp.oc_without_padding = jcp.oc;
980 jcp.ic = src_d.dims()[1] / jcp.ngroups;
981 jcp.ic_without_padding = jcp.ic;
982 jcp.id = is_3d ? src_d.dims()[2] : 1;
983 jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
984 jcp.iw = src_d.dims()[ndims - 1];
985 jcp.od = is_3d ? dst_d.dims()[2] : 1;
986 jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
987 jcp.ow = dst_d.dims()[ndims - 1];
988 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
989 jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
990 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
991 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
992 jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
993 jcp.l_pad = cd.padding[0][ndims - 3];
994 jcp.stride_d = is_3d ? cd.strides[0] : 1;
995 jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
996 jcp.stride_w = cd.strides[ndims - 3];
997 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
998
999 if (!(jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1))
1000 return status::unimplemented;
1001
1002 if (!(jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0))
1003 return status::unimplemented;
1004
1005 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
1006 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
1007 jcp.dilate_w = cd.dilates[ndims - 3];
1008
1009 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
1010
1011 if (jcp.dilate_d != 0 || jcp.dilate_h != 0 || jcp.dilate_w != 0)
1012 return status::unimplemented;
1013 if (jcp.is_depthwise)
1014 return status::unimplemented; // TODO: add support of DW convolution
1015 if (jcp.ngroups > 1)
1016 return status::unimplemented; // TODO: add support for non-unit groups
1017
1018 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
1019 jcp.dst_dt = cd.dst_desc.data_type;
1020 jcp.src_dt = cd.src_desc.data_type;
1021 jcp.wei_dt = cd.weights_desc.data_type;
1022
1023 // Dispatch small shapes to VNNI for better performance
1024 const auto is_small_shape = jcp.od * jcp.oh * jcp.ow <= 4 && jcp.ic <= 512
1025 && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= static_cast<int32_t>(
1026 platform::get_per_core_cache_size(1) / 2);
1027 const auto is_3d_small_ic = jcp.ndims == 5 && jcp.ic * jcp.oc <= 32
1028 && jcp.od >= 128 && jcp.oh >= 128 && jcp.ow >= 128;
1029 if (is_small_shape || is_3d_small_ic) return status::unimplemented;
1030
1031 const auto zp = attr.zero_points_;
1032 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
1033 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
1034 jcp.zp_src_is_common = zp.common(
1035 DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
1036 if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
1037 || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
1038 is_int8_convolution))
1039 return status::unimplemented;
1040
1041 jcp.nthr = nthreads;
1042
1043 jcp.ic_block = 16;
1044 jcp.ic_block_int = is_bf16_convolution ? 32 : 64;
1045 jcp.ic_block_int_np = jcp.ic_block_int;
1046 if (jcp.ic_block_int < jcp.ic_without_padding
1047 && jcp.ic_without_padding % jcp.ic_block_int != 0) {
1048 // Order of blocks comes from empirical observation
1049 static const int try_blocks[] = {32, 48, 40, 56};
1050 for (auto blk_size : try_blocks) {
1051 const int _blk_size = is_bf16_convolution ? blk_size / 2 : blk_size;
1052 if (jcp.ic_without_padding % _blk_size == 0) {
1053 jcp.ic_block_int_np = _blk_size;
1054 break;
1055 }
1056 }
1057 }
1058 jcp.oc_block = 16;
1059
1060 bool args_ok = true && jcp.ic % 4 == 0
1061 && (jcp.ow == jcp.iw && jcp.stride_w == 1)
1062 && (jcp.oh == jcp.ih && jcp.stride_h == 1)
1063 && (jcp.od == jcp.id && jcp.stride_d == 1);
1064 if (!args_ok) return status::unimplemented;
1065
1066 if (jcp.ngroups == 1) {
1067 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1068 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1069 }
1070
1071 auto set_or_check_wei_format = [&]() {
1072 using namespace format_tag;
1073 using namespace memory_extra_flags;
1074 format_tag_t wei_tag;
1075 wei_tag = (is_bf16_convolution)
1076 ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i, gOIw16i16o2i,
1077 OIhw16i16o2i, gOIhw16i16o2i, OIdhw16i16o2i,
1078 gOIdhw16i16o2i)
1079 : pick(with_groups + 2 * (ndims - 3), OIw16i16o4i, gOIw16i16o4i,
1080 OIhw16i16o4i, gOIhw16i16o4i, OIdhw16i16o4i,
1081 gOIdhw16i16o4i);
1082 memory_desc_t want_wei_md = weights_md;
1083 memory_desc_init_by_tag(want_wei_md, wei_tag);
1084
1085 if (jcp.src_zero_point) {
1086 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
1087 want_wei_md.extra.asymm_compensation_mask = (1 << 0)
1088 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
1089 }
1090 if (weights_md.format_kind == format_kind::any) {
1091 weights_md = want_wei_md;
1092 return true;
1093 }
1094 return weights_md == want_wei_md;
1095 };
1096
1097 if (!set_or_check_wei_format()) { return status::unimplemented; }
1098
1099 format_tag_t dat_tag = utils::pick(
1100 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1101
1102 if (src_d.format_kind() == format_kind::any) {
1103 CHECK(memory_desc_init_by_tag(src_md, dat_tag));
1104 jcp.src_tag = dat_tag;
1105 } else {
1106 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
1107 }
1108 if (jcp.src_tag != dat_tag) { return status::unimplemented; }
1109
1110 if (dst_d.format_kind() == format_kind::any) {
1111 CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
1112 jcp.dst_tag = dat_tag;
1113 } else {
1114 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
1115 }
1116 if (jcp.dst_tag != dat_tag) { return status::unimplemented; }
1117
1118 if (jcp.with_bias) {
1119 if (bias_d.format_kind() == format_kind::any)
1120 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
1121 }
1122
1123 CHECK(attr.set_default_formats(&dst_md));
1124
1125 const auto &p = attr.post_ops_;
1126
1127 const int sum_ind = p.find(primitive_kind::sum);
1128 jcp.with_sum = sum_ind != -1;
1129 const int eltwise_ind = p.find(primitive_kind::eltwise);
1130 jcp.with_eltwise = eltwise_ind != -1;
1131 const int binary_ind = p.find(primitive_kind::binary);
1132 jcp.with_binary = binary_ind != -1;
1133 jcp.sum_dt = p.get_sum_dt(jcp.dst_dt);
1134
1135 jcp.post_ops = p;
1136 jcp.is_fast_postops = is_fast_postops(jcp);
1137
1138 using namespace injector;
1139 const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
1140 const bool sum_requires_scale_one = sum_at_pos_0_only;
1141 const bool sum_requires_zp_zero = sum_at_pos_0_only;
1142 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
1143 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
1144 sum_requires_zp_zero});
1145 if (!post_ops_ok_) return status::unimplemented;
1146
1147 jcp.typesize_in = types::data_type_size(src_d.data_type());
1148 jcp.typesize_out = types::data_type_size(dst_d.data_type());
1149 jcp.typesize_bia
1150 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
1151 jcp.typesize_acc = sizeof(int32_t);
1152
1153 jcp.nb_ic = jcp.ic / jcp.ic_block;
1154 jcp.nb_oc = jcp.oc / jcp.oc_block;
1155 jcp.nb_ic_int = div_up(jcp.ic_without_padding, jcp.ic_block_int_np);
1156
1157 jcp.max_width = amx::get_max_rows(amx::get_target_palette());
1158 if (jcp.max_width <= 0) return status::unimplemented;
1159
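    // Pick jcp.tile_width (the number of tile rows, i.e. output pixels per
    // tile) as the largest value up to max_width that evenly divides the
    // output spatial size, preferring divisors of od*oh for larger shapes so
    // that rows do not straddle output lines. Illustrative example
    // (hypothetical shape): with od*oh = 56 and max_width = 16, the first
    // divisor found going down from 16 is 14, so tile_width becomes 14.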
    const int size_threshold = 32;
1161 const int min_width
1162 = 1; // TODO: Possible optimizations: do not use small values
1163 const int spatial = jcp.od * jcp.oh;
1164 const int os = jcp.od * jcp.oh * jcp.ow;
1165
1166 jcp.tile_width = 1;
1167 for (int s_size = jcp.max_width; s_size >= min_width; s_size--) {
        if ((spatial >= size_threshold && spatial % s_size == 0)
                || (spatial < size_threshold && os % s_size == 0)) {
1170 jcp.tile_width = s_size;
1171 break;
1172 }
1173 }
1174 if (jcp.tile_width == 1) {
1175 jcp.tile_width = nstl::min(jcp.max_width, os);
1176 jcp.tile_tail = os % jcp.max_width;
1177 for (int i = jcp.max_width; i >= min_width; i--) {
1178 int i_tail = os % i;
1179 if (i_tail > jcp.tile_tail || i_tail == 0) {
1180 jcp.tile_width = i;
1181 jcp.tile_tail = i_tail;
1182 if (i_tail == 0) break;
1183 }
1184 }
1185 if (jcp.tile_width < min_width && jcp.tile_tail < min_width)
1186 jcp.tile_tail = 0;
1187 }
1188
1189 /* TODO: Add stride support !
1190 while ((jcp.stride_h != 1 || jcp.stride_w != 1)
1191 && (jcp.ow % jcp.tile_width != 0) || jcp.tile_width > 16) {
1192 jcp.tile_width = jcp.ow / 2;
1193 }
1194 */
1195
1196 // TODO: Add support for spatial tails
1197 if (jcp.tile_tail != 0) return status::unimplemented;
1198
    // TODO: Implement efficient tile tail processing. For now, fall back to
    // the common case when half of the tile or less would be utilized.
1201 if (jcp.tile_width <= jcp.max_width / 2) return status::unimplemented;
1202
1203 jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
1204 jcp.nb_ic_blocking = 1;
1205 jcp.nb_os_blocking = (os / jcp.tile_width > 2) ? 2 : 1;
1206 jcp.nb_os2_blocking = (jcp.nb_os_blocking > 1)
1207 ? ((jcp.nb_os_blocking * jcp.tile_width) % 2 == 0) ? 2 : 1
1208 : 1;
1209 jcp.nb_os = os / jcp.tile_width;
1210
1211 jcp.wsp_buffer_size = (size_t)2 * jcp.nb_os_blocking * jcp.nb_oc_blocking
1212 * jcp.max_width * jcp.oc_block;
1213
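    // per_one_pstore is the number of scratchpad rows interleave_store()
    // drains per tile multiply: the total number of rows to store divided by
    // the number of tdpb* calls available to hide them, disabled (set to 0)
    // when the ratio becomes too large to interleave profitably.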
1214 int ops_tile_store
1215 = jcp.nb_oc_blocking * jcp.nb_os_blocking * jcp.tile_width;
    int available_ops = jcp.nb_ic_int * jcp.nb_oc_blocking * jcp.nb_os_blocking;
    jcp.per_one_pstore
            = (available_ops) ? ops_tile_store / available_ops + 1 : 0;
1219 if (jcp.per_one_pstore > 12) jcp.per_one_pstore = 0;
1220
1221 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
1222 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
1223 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
1224 const int wei_mask_per_oc = 1 << (int)with_groups;
1225 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
1226 jcp.dst_scale = !dst_scales.has_default_values();
1227
1228 // only common src & dst scales are supported
1229 // only common and per-oc-channel weight scales are supported
1230 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
1231 && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
1232 if (!scales_ok) return status::unimplemented;
1233
1234 return status::success;
1235}
1236
1237void jit_avx512_core_amx_1x1_fwd_kernel_t::init_scratchpad(
1238 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
1239 const primitive_attr_t &attr) {
1240 scratchpad.book(key_conv_amx_wsp_buffer, jcp.nthr * jcp.wsp_buffer_size,
1241 jcp.typesize_acc);
1242 if (jcp.ic_without_padding % jcp.ic_block_int_np)
1243 scratchpad.book(key_conv_amx_tile_buffer,
1244 jcp.nthr * (jcp.wsp_buffer_size / 2), jcp.typesize_acc);
1245 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
1246 assert(jcp.ngroups == 1);
1247 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
1248 }
1249 scratchpad.book(key_conv_amx_tilecfg, 2, 64); // 2 whole cachelines
1250 book_precomputed_scales(
1251 scratchpad, attr.scales_, jcp.ngroups * jcp.oc_without_padding);
1252}
1253
1254} // namespace x64
1255} // namespace cpu
1256} // namespace impl
1257} // namespace dnnl
1258
1259// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
1260