1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/memory_tracking.hpp"
20#include "common/nstl.hpp"
21#include "common/type_helpers.hpp"
22#include "common/utils.hpp"
23
24#include "cpu/platform.hpp"
25#include "cpu/scale_utils.hpp"
26#include "cpu/x64/cpu_barrier.hpp"
27#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
28#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
29#include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"
30
31#define GET_OFF(field) offsetof(jit_conv_call_s, field)
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace x64 {
37
38using namespace dnnl::impl::memory_tracking::names;
39using namespace dnnl::impl::data_type;
40using namespace dnnl::impl::utils;
41using namespace Xbyak;
42
43void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) {
44 for (int oc = 0; oc < jcp.nb_oc_blocking; oc++)
45 for (int ur = 0; ur < ur_w; ur++) {
46 const Zmm zmm = zmm_out(ur, oc);
47 vpxord(zmm, zmm, zmm);
48 }
49}
50
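// Multiply each accumulated weight-sum register by the source zero-point and
// store the result into the zero-point pad buffer; the last oc block may be
// partially written through ktail_mask.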
51void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
52 int ur_w, bool last_oc_block_flag) {
53 assert(jcp.is_nspc);
54
55 const int nb_oc_block = jcp.nb_oc_blocking;
56 const int oc_block = jcp.oc_block;
57
58 const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);
59
    /* write out the accumulated registers to the zero-point pad buffer */
61 for (int oc = 0; oc < nb_oc_block; oc++) {
62 const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
63 for (int ur = 0; ur < ur_w; ur++) {
64 const int output_offset = sizeof(int32_t)
65 * (oc * oc_block
66 + ur * jcp.oc_without_padding * jcp.ngroups);
67 const Zmm zmm_dst = zmm_out(ur, oc);
68 const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
69 // multiply dst by src_zero_point
70 vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
71 vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
72 }
73 }
74}
75
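// Accumulate conv(1, wei_s8) over the current ic blocks for every output
// position that falls into a padded region: vpdpbusd with the all-ones
// vector (zmm_one) sums the int8 weights; the reduced-lowering path first
// gathers the weights through vpermi2b using the permute-index table.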
76void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
77 int pad_r, ic_block_t last_ic_block_flag, bool padded) {
78
79 const int kw = jcp.kw;
80 const int ic_block = jcp.ic_block_int_np;
81 const int oc_block = jcp.oc_block;
82 const int nb_oc_block = jcp.nb_oc_blocking;
83
84 const bool ic_tail
85 = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
86 const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;
87
    /* Skip the trailing input loads when
       (ic_without_padding % ic_block_int) / ic_inner_block < ic_block / ic_inner_block */
90 const int icb = (last_ic_block_flag == last_ic_block)
91 ? div_up(
92 (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
93 : ic_block / ic_inner_block;
94
95 auto get_filter_offset = [=](int ocb, int ic, int ki) {
96 size_t w_step = jcp.is_relo ? jcp.kh : 1;
97 size_t kw_offset = static_cast<size_t>(ki) * w_step
98 * jcp.ic_block_int_np * jcp.oc_block;
99 size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
100 * jcp.ic_block_int_np * jcp.oc_block;
101 size_t offset = kw_offset
102 + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
103 + static_cast<size_t>(ic) * oc_block * ic_inner_block;
104 return sizeof(char) * offset;
105 };
106 auto compute_fma = [=](const Zmm zmm_accum, const int ic,
107 const Address addr) {
108 if (jcp.is_relo) {
109 vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
110 const Zmm r_zmm = masked_write && ic == icb - 1
111 ? zmm_permb | kmask_ic_block | T_z
112 : zmm_permb;
113 // only values from 'src2' are used to write dst
114 vpermi2b(r_zmm, zmm_permb, addr);
115 vpdpbusd(zmm_accum, zmm_one,
116 zmm_permb); // XXX - using the same register for all ur_w
117 } else {
118 vpdpbusd(zmm_accum, zmm_one, addr);
119 }
120 };
121
122 if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
123 const Reg64 reg_tmp = reg_scratch;
124 mov(reg_tmp, ic_mask_label);
125 kmovq(kmask_ic_block, qword[reg_tmp]);
126 }
127 if (jcp.is_relo) mov(reg_scratch, permb_idx_label);
128
129 for (int ki = 0; ki < kw; ki++) {
130 const int ur_start = get_ow_start(ki, pad_l);
131 const int ur_end = get_ow_end(ur_w, ki, pad_r);
132 for (int ur = 0; ur < ur_w; ur++) {
            // Calculate zero_point padding as:
            // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0
135 if (ur < ur_start || ur >= ur_end || padded) {
136 for (int oc = 0; oc < nb_oc_block; oc++) {
137 const Zmm zmm_accum = zmm_out(ur, oc);
138 for (int ic = 0; ic < icb; ic++) {
139 const auto addr_filt = EVEX_compress_addr(
140 aux_reg_filt, get_filter_offset(oc, ic, ki));
141 compute_fma(zmm_accum, ic, addr_filt);
142 }
143 }
144 }
145 }
146 }
147}
148
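// Filter-height loop: handle rows overlapping the top padding, then the
// regular kh body, then rows overlapping the bottom padding.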
149void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l,
150 int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {
151
152 Label kh_label, skip_kh_loop;
153 const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
154 const size_t shift_wei_h_step = sizeof(char)
155 * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
156 * jcp.oc_block;
157
158 // Compute zero_point compensation for the padded region. Total compute
159 // area is 'overflow * kw' where 'overflow' indicates the overlap
160 // between the filter and either top_pad or bottom_pad region.
161 auto compute_kh_loop = [=](size_t param_overflow) {
162 Label overflow_label, no_overflow_label;
163
164 mov(reg_overflow, ptr[param1 + param_overflow]);
165 cmp(reg_overflow, 0);
166 je(no_overflow_label, T_NEAR);
167 L(overflow_label);
168 {
169 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
170 add(aux_reg_filt, shift_wei_h_step);
171 dec(reg_overflow);
172 jne(overflow_label, T_NEAR);
173 }
174 L(no_overflow_label);
175 };
176
177 if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow));
178
179 // check for holes and skip computation due to dilation
180 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
181 if ((jcp.dilate_h >= jcp.ih)) {
182 cmp(reg_kj, 0);
183 je(skip_kh_loop, T_NEAR);
184 }
185
186 L(kh_label);
187 {
188 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
189
190 add(aux_reg_filt, shift_wei_h_step);
191 dec(reg_kj);
192 jne(kh_label, T_NEAR);
193 }
194
195 L(skip_kh_loop);
196
197 if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow));
198}
199
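// Filter-depth loop for 3D convolution: front-padding overlap, the regular
// kd body (each iteration runs a full kh_loop), then back-padding overlap.
// For 2D shapes this collapses into a single kh_loop call.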
200void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l,
201 int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {
202
203 Label kd_label, skip_kd_loop;
204 const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
205 const size_t shift_wei_h_step = sizeof(char)
206 * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
207 * jcp.oc_block;
208
209 // Compute zero_point compensation for the padded region. Total compute
210 // area is 'overflow * kh * kw' where 'overflow' indicates the overlap
211 // between the filter and either front_pad or back_pad region.
212 auto compute_kd_loop = [=](size_t param_overflow) {
213 Label kh_loop_label;
214 Label no_overflow_label, overflow_label;
215
216 mov(reg_ki, ptr[param1 + param_overflow]);
217 cmp(reg_ki, 0);
218 je(no_overflow_label, T_NEAR);
219 L(overflow_label);
220 {
221 mov(aux_reg_filt, aux_reg_filt_d);
222 mov(reg_kj, jcp.kh);
223 L(kh_loop_label);
224 {
225 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
226 add(aux_reg_filt, shift_wei_h_step);
227 dec(reg_kj);
228 jne(kh_loop_label, T_NEAR);
229 }
230 add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
231 dec(reg_ki);
232 jne(overflow_label, T_NEAR);
233 }
234 L(no_overflow_label);
235 };
236
237 const bool zp_d_padding
238 = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
239 if (zp_d_padding) {
240 mov(aux_reg_filt_d, reg_filt);
241 compute_kd_loop(GET_OFF(f_overflow));
242
243 // check for holes and skip computation due to dilation
244 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
245 if (jcp.dilate_d >= jcp.id) {
246 cmp(reg_ki, 0);
247 je(skip_kd_loop, T_NEAR);
248 }
249 L(kd_label);
250 mov(aux_reg_filt, aux_reg_filt_d);
251
252 } else {
253 mov(aux_reg_filt, reg_filt);
254 }
255
256 kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad);
257
258 if (zp_d_padding) {
259 add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
260 dec(reg_ki);
261 jne(kd_label, T_NEAR);
262
263 L(skip_kd_loop);
264
265 compute_kd_loop(GET_OFF(back_overflow));
266 }
267}
268
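// Loop over input-channel blocks: accumulate the zero-point compensation in
// zmm registers across all icb iterations, then store it to the pad buffer,
// handling the ic and oc tails separately.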
269void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
270 int ur_w, int pad_l, int pad_r, bool handle_h_pad) {
271
272 Label icb_label;
273 const size_t nb_ic = jcp.nb_ic_int;
274 const bool do_icb_loop = nb_ic > 1;
275
276 /* Initialize zmm_one for weight accumulation */
277 xor_(reg_scratch, reg_scratch);
278 const Reg8 _t8 = reg_scratch.cvt8();
279 mov(_t8, 0x1);
280 vpbroadcastb(zmm_one, _t8);
281
282 prepare_output(ur_w);
283
284 mov(reg_icb, nb_ic);
285
286 L(icb_label);
287 if (jcp.ic_without_padding != jcp.ic) {
288 Label common_ker, end_ker;
289 if (do_icb_loop) {
290 cmp(reg_icb, 1); // The last ic block
291 jne(common_ker, T_NEAR);
292 }
293 kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
294 if (do_icb_loop) {
295 jmp(end_ker, T_NEAR);
296
297 L(common_ker);
298 kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
299
300 L(end_ker);
301 }
302 } else {
303 kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
304 }
305 // End of IC Loop
306 if (do_icb_loop) {
307 const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
308 * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
309 add(reg_filt, sizeof(char) * shift_wei_icb_step);
310
311 dec(reg_icb);
312 cmp(reg_icb, 0);
313 jg(icb_label, T_NEAR);
314
315 sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
316 }
317
318 if (jcp.oc_without_padding != jcp.oc) {
319 Label common_store, end_store;
320
321 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
322 jne(common_store, T_NEAR);
323
324 store_output(ur_w, true); // last oc block
325 jmp(end_store, T_NEAR);
326
327 L(common_store);
328 store_output(ur_w, false);
329
330 L(end_store);
331 } else {
332 store_output(ur_w, false);
333 }
334}
335
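// Generate the width traversal in three phases: output columns overlapping
// the left padding, a single middle column (emitted only when height/depth
// padding must also be handled), and columns overlapping the right padding.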
336void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
337 const bool h_padding) {
338
339 auto ur_w_shift = [&](const int ur_w) {
340 return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
341 };
342
343 const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
344 / (jcp.nb_oc_blocking);
345 const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
346 int l_pad = jcp.l_pad;
347
348 const int l_pad_output = jcp.l_pad_output;
349 const int r_pad_output = jcp.r_pad_output;
350
351 // a single middle element (if required) containing only height padding
352 const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);
353
354 const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
355 const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);
356
357 int ow = 0;
358 int cur_l_pad_output = l_pad_output;
359 while (cur_l_pad_output > 0) {
360 const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
361 ow += ur_w;
362 const int cur_r_pad = calculate_end_padding(
363 jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
364 icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
365 add(reg_zp_pbuff, ur_w_shift(ur_w));
366
367 l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
368 cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
369 }
370
371 if (no_pad > 0) {
372 const int ur_w = 1;
373 if (h_padding) icb_loop(ur_w, 0, 0, true);
374 if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
375 }
376 assert(ow + no_pad == ow_start);
377
378 ow = ow_start;
379 int cur_r_pad_output = r_pad_start;
380 while (cur_r_pad_output > 0 && ow < jcp.ow) {
381 const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
382 ow += ur_w;
383 const int cur_r_pad = calculate_end_padding(
384 jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
385 icb_loop(ur_w, 0, cur_r_pad, h_padding);
386 add(reg_zp_pbuff, ur_w_shift(ur_w));
387
388 cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
389 }
390}
391
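// Kernel entry point for the zero-point pad buffer: the width traversal is
// generated twice, once for rows without height/depth padding and once for
// rows that additionally require it, selected at runtime from t/b_overflow
// and kd_padding.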
392void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
393 Label h_pad_label, end_label;
394
395 assert(jcp.req_zero_point_buffer);
396 assert(jcp.typesize_in == sizeof(char));
397
398 preamble();
399
400 mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
401 mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
402 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
403
404 if (jcp.oc_without_padding != jcp.oc) {
405 const Reg32 reg_tmp = reg_scratch.cvt32();
406 const int tail_size = jcp.oc_without_padding % jcp.oc_block;
407 const int mask = (1 << tail_size) - 1;
408 mov(reg_tmp, mask);
409 kmovw(ktail_mask, reg_tmp);
410 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
411 }
412
413 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
414 cmp(reg_overflow, 0);
415 jne(h_pad_label, T_NEAR);
416 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
417 cmp(reg_overflow, 0);
418 jne(h_pad_label, T_NEAR);
419 if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
420 mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
421 cmp(reg_overflow, jcp.kd);
422 jne(h_pad_label, T_NEAR);
423 }
424
425 // Handle width padding region
426 unroll_width(false);
427 jmp(end_label, T_NEAR);
428
429 // handle height padding region
430 L(h_pad_label);
431 unroll_width(true);
432
433 L(end_label);
434
435 postamble();
436
437 // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
438 // permute elements through permb into the VNNI layout '...16i4i'.
439 if (jcp.is_relo) {
440 align(64);
441 L(permb_idx_label);
442 // permb: id-bit for table selection is bit[6]
443 const uint8_t select_src2_bit = 0x40;
444 // permb: bits [5:0] select the element within each input table
445 const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
446 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
447 38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
448 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
449 62, 15, 31, 47, 63};
450 for (size_t i = 0; i < 64; ++i)
451 db(select_src2_bit | permb_idx_table[i]);
452
453 // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
454 const int ic_tail_size
455 = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
456 if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
457 align(64);
458 L(ic_mask_label);
459
460 assert(4 > ic_tail_size);
461 // mask is on a 4-bit basis from the 4 ic elements in a zmm
462 const int nibble = (1 << ic_tail_size) - 1;
463 for (int i = 0; i < 16; ++i) {
464 db(nibble | (nibble << 4));
465 }
466 }
467 }
468}
469
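// Reorder the reduced-lowering weights from the blocked '..i16o'-style layout
// into the '~OR16oVr' layout (V = vnni_width) consumed by the tile loads,
// zero-filling the ragged tail of each oc block.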
470void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {
471
472 const bool is_bf16 = jcp.src_dt == data_type::bf16;
473
474 // required for use of VPERMB instruction
475 assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
476 assert(jcp.ic_block_int * jcp.typesize_in == 64);
477
478 preamble();
479
480 mov(reg_src, ptr[param1 + GET_OFF(src)]);
481 mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
482
483 // load permute indices from data section
484 Label permute_index_table;
485 mov(reg_tmp, permute_index_table);
486 if (is_bf16)
487 vmovdqu16(zmm_idx, ptr[reg_tmp]);
488 else
489 vmovdqu8(zmm_idx, ptr[reg_tmp]);
490
491 const int vnni_width = is_bf16 ? 2 : 4;
492 const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
493 const int nb_r = div_up(r, vnni_width);
494 const int rtail = (r % vnni_width) * jcp.oc_block;
495 if (rtail > 0) {
496 uint64_t mask = (UINT64_C(1) << rtail) - 1;
497 mov(reg_tmp, mask);
498 kmovq(kmask_load, reg_tmp);
499 }
500 const int nb_z = rnd_up(nb_r, jcp.ic_block);
501 if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);
502
503 const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
504 const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
505 const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);
506
507 // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
508 for (int g = 0; g < jcp.ngroups; g++) {
509 for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
510 int offset = 0;
511 int rb = 0;
512 for (; rb < nb_r; offset += 64, rb++) {
513 auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
514 ? zmm_src | kmask_load | T_z
515 : zmm_src;
516 if (is_bf16) {
517 vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
518 vpermw(zmm_dst, zmm_idx, zmm_src);
519 vmovdqu16(ptr[reg_dst + offset], zmm_dst);
520 } else {
521 vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
522 vpermb(zmm_dst, zmm_idx, zmm_src);
523 vmovdqu8(ptr[reg_dst + offset], zmm_dst);
524 }
525 }
526 for (; rb < nb_z; offset += 64, rb++) {
527 if (is_bf16)
528 vmovdqu16(ptr[reg_dst + offset], zmm_zero);
529 else
530 vmovdqu8(ptr[reg_dst + offset], zmm_zero);
531 }
532 add(reg_src, ocb_src_step);
533 add(reg_dst, ocb_dst_step);
534 }
535 }
536
537 postamble();
538
539 align(64);
540 L(permute_index_table);
541 const uint8_t no = 16; // 16o
542 const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
543 for (uint8_t o = 0; o < no; ++o) {
544 for (uint8_t r = 0; r < nr; r++) {
545 const uint8_t index = o + r * no;
546 if (is_bf16)
547 dw(index);
548 else
549 db(index);
550 }
551 }
552}
553
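// Copy (and zero-pad) a single input row into the padded source buffer. For
// the strided pbuffer layout the row is written as min(gen_kw, stride_w)
// interleaved sets, one per stride index, as described below.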
554void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
555 int lpad, int iw_len, int icb) {
556
557 const bool is_bf16 = jcp.src_dt == data_type::bf16;
558 int iwp_idx = 0;
    // there are min(gen_kw, jcp.stride_w) contiguous sets of input
    // data (one for each stride idx); they are placed one after another
    // without additional padding
562 const bool are_sets_interleaved
563 = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
564 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
565 const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
566 for (int set_idx = 0; set_idx < num_sets; set_idx++) {
567 int set_width_padded = !jcp.is_pbuffer_strided
568 ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
569 : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
570 + (set_idx < gen_kw % num_sets ? 1 : 0)
571 : jcp.ow_block;
572 for (int set_shift = 0; set_shift < set_width_padded;
573 set_shift++, iwp_idx++) {
574 int iw_idx = set_idx * (jcp.dilate_w + 1)
575 + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
576 - lpad;
577 size_t out_base_offset
578 = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
579 if (iw_idx < 0 || iw_idx >= iw_len) {
580 // left or right padding
581 vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
582 } else if (jcp.is_nspc) {
583 size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
584 * jcp.ngroups * jcp.ic_without_padding;
585 int ic = icb * jcp.ic_block_int_np;
586 // TODO: use Xmm or Ymm moves for better small ic efficiency
587 auto zmm_tmp_mask
588 = ic + jcp.ic_block_int <= jcp.ic_without_padding
589 ? zmm_tmp
590 : zmm_tmp | ktail_mask | T_z;
591 if (is_bf16) {
592 vmovdqu16(
593 zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
594 vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
595 } else {
596 vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
597 vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
598 }
599 } else {
600 assert(is_bf16);
601 size_t inp_w_offset
602 = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
603 for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
604 int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
605 size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
606 * jcp.iw * jcp.ic_block
607 + inp_w_offset;
608 if (ic + jcp.ic_block <= jcp.ic) {
609 vmovdqu16(
610 ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
611 } else {
612 vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
613 }
614 size_t out_offset = out_base_offset
615 + (size_t)jcp.typesize_in * j * jcp.ic_block;
616 vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
617 }
618 }
619 }
620 }
621}
622
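// Copy one row, specializing the generated code for the first ow block (left
// padding), the last/penultimate blocks (ow tail or right padding) and the
// general case; the variant is selected at runtime through reg_owb.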
623void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
624 if (jcp.nb_ow == 1) {
625 copy_row_body(jcp.l_pad, jcp.iw, icb);
626 } else {
627 auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
628 return (cur_ow_block - 1) * jcp.stride_w
629 + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
630 };
631
632 auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
633 auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
634 if (owb < 0) return len_req;
635 int ow_block_start = nstl::max(
636 0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
637 return nstl::min(jcp.iw - ow_block_start, len_req);
638 };
639
640 int general_owb_cases = jcp.nb_ow;
641 Xbyak::Label copy_row_done_label;
642 bool special_first_block_case = jcp.l_pad > 0;
643 if (special_first_block_case) {
644 general_owb_cases--;
645 Xbyak::Label skip_first_block_case_label;
646 cmp(reg_owb, 0);
647 jne(skip_first_block_case_label, T_NEAR);
648 copy_row_body(jcp.l_pad,
649 get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
650 jmp(copy_row_done_label, T_NEAR);
651 L(skip_first_block_case_label);
652 }
653 bool special_last_block_case = false
654 // has ow_block_tail
655 || jcp.ow % jcp.ow_block != 0
656 // there is no ow_block_tail but right padding exists
657 || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
658 != get_iw_len_required(jcp.ow_block, 0);
659 if (special_last_block_case) {
660 general_owb_cases--;
661 Xbyak::Label skip_last_block_case_label;
662 cmp(reg_owb, jcp.nb_ow - 1);
663 jne(skip_last_block_case_label, T_NEAR);
664 int ow_block_tail = jcp.ow % jcp.ow_block;
665 int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
666 copy_row_body(
667 0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
668 jmp(copy_row_done_label, T_NEAR);
669 L(skip_last_block_case_label);
670 }
671
672 bool special_penult_block_case = true
673 // if nb_ow = 2 and l_pad > 0 it's the same as
674 // special_first_block_case
675 && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
676 // right padding exists in penult block
677 && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
678 != get_iw_len_required(jcp.ow_block, 0);
679 if (special_penult_block_case) {
680 general_owb_cases--;
681 Xbyak::Label skip_penult_block_case_label;
682 cmp(reg_owb, jcp.nb_ow - 2);
683 jne(skip_penult_block_case_label, T_NEAR);
684 copy_row_body(
685 0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
686 jmp(copy_row_done_label, T_NEAR);
687 L(skip_penult_block_case_label);
688 }
689
690 if (general_owb_cases > 0) // general case
691 copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);
692
693 L(copy_row_done_label);
694 }
695}
696
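// Reduced-lowering copy: writes the nspc source rows together with explicit
// zero rows/columns for the left, top, bottom and right overflow regions;
// for bf16 an extra zero cacheline is appended to avoid reading NaNs.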
697void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
698 assert(jcp.nb_ic_int == 1);
699 assert(jcp.ic_block_int * jcp.typesize_in == 64);
700 assert(jcp.is_nspc);
701
702 auto load_mask = [=](int tail, Opmask kmask) {
703 uint64_t mask = (UINT64_C(1) << tail) - 1;
704 mov(reg_tmp, mask);
705 kmovq(kmask, reg_tmp);
706 };
707
708 const bool is_bf16 = jcp.src_dt == data_type::bf16;
709 const int inp_w_step
710 = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
711 const int inp_h_step = jcp.iw * inp_w_step;
712 const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
713 const int out_w_step = jcp.kh * out_h_step;
714 const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
715 if (tail_size > 0) load_mask(tail_size, ktail_mask);
716
717 auto zero_it = [=](reg64_t tmp_out_ptr) {
718 for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
719 const int offset = ic * jcp.typesize_in;
720 const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
721 Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
722 if (is_bf16)
723 vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
724 else
725 vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
726 }
727 };
728
729 // pointer to 1st needed element in src buffer
730 mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
731 // pointer to 1st needed element in dst buffer
732 mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
733
734 // total number of rows to copy
735 mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);
736
737 // number of rows of src buffer to copy
738 mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
739 // number of zero-padded rows above src buffer to copy
740 mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
741 // number of zero-padded rows below src buffer to copy
742 mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
743
744 // number of columns of src buffer to copy
745 mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
    // number of zero-padded columns after src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);
750
751 vpxord(zmm_zero, zmm_zero, zmm_zero);
752
753 { // Handle Left Overflow
754 Label label_lov, label_lov_skip;
755 test(reg_lov, reg_lov);
756 jz(label_lov_skip, T_NEAR);
757 L(label_lov); // handle left or right overflow
758 {
759 Label label_lov_inner;
760 mov(reg_aux_out_ptr, reg_out_ptr);
761 mov(reg_cnt, reg_kht);
762 L(label_lov_inner);
763 {
764 zero_it(reg_aux_out_ptr);
765 add(reg_aux_out_ptr, out_h_step);
766 dec(reg_cnt);
767 jnz(label_lov_inner, T_NEAR);
768 }
769 add(reg_out_ptr, out_w_step);
770 dec(reg_lov);
771 jnz(label_lov, T_NEAR);
772 }
773 L(label_lov_skip);
774 }
775
776 // save output pointer for later use
777 mov(reg_save_out_ptr, reg_out_ptr);
778
779 // just in case there is no meat...
780 Label label_kwp_end;
781 test(reg_kwp, reg_kwp);
782 jz(label_kwp_end, T_NEAR);
783
    // Dispatch between top-overflow, kh-body and bottom-overflow handling
785 Label label_tov;
786 Label label_khp, label_no_khp;
787 Label label_bov;
788 test(reg_tov, reg_tov);
789 jnz(label_tov, T_NEAR);
790 test(reg_khp, reg_khp);
791 jnz(label_khp, T_NEAR);
792 test(reg_bov, reg_bov);
793 jnz(label_bov, T_NEAR);
794 jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters
795
796 L(label_tov); // handle top overflow
797 {
798 Label label_tov_inner;
799 mov(reg_aux_out_ptr, reg_out_ptr);
800 mov(reg_cnt, reg_kwp);
801 L(label_tov_inner);
802 {
803 zero_it(reg_aux_out_ptr);
804 add(reg_aux_out_ptr, out_w_step);
805 dec(reg_cnt);
806 jnz(label_tov_inner, T_NEAR);
807 }
808 add(reg_out_ptr, out_h_step);
809 dec(reg_tov);
810 jnz(label_tov, T_NEAR);
811 }
812 test(reg_khp, reg_khp);
813 jz(label_no_khp, T_NEAR);
814 L(label_khp); // handle kh padding (not fully unrolled)
815 {
816 Label label_khp_inner;
817 mov(reg_aux_inp_ptr, reg_inp_ptr);
818 mov(reg_aux_out_ptr, reg_out_ptr);
819 mov(reg_cnt, reg_kwp);
820 L(label_khp_inner);
821 {
822 for (int ic = 0; ic < jcp.ic_without_padding;
823 ic += jcp.ic_block_int) {
824 const int offset = ic * jcp.typesize_in;
825 const bool masked
826 = ic + jcp.ic_block_int > jcp.ic_without_padding;
827 // zero masking is needed to avoid dependency on destination
828 Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
829 Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
830 if (is_bf16) {
831 vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
832 vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
833 } else {
834 vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
835 vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
836 }
837 }
838 add(reg_aux_inp_ptr, inp_w_step);
839 add(reg_aux_out_ptr, out_w_step);
840 dec(reg_cnt);
841 jnz(label_khp_inner, T_NEAR);
842 }
843 add(reg_inp_ptr, inp_h_step);
844 add(reg_out_ptr, out_h_step);
845 dec(reg_khp);
846 jnz(label_khp, T_NEAR);
847 }
848 L(label_no_khp);
849 test(reg_bov, reg_bov);
850 jz(label_kwp_end, T_NEAR);
851 L(label_bov); // handle bottom overflow
852 {
853 Label label_bov_inner;
854 mov(reg_aux_out_ptr, reg_out_ptr);
855 mov(reg_cnt, reg_kwp);
856 L(label_bov_inner);
857 {
858 zero_it(reg_aux_out_ptr);
859 add(reg_aux_out_ptr, out_w_step);
860 dec(reg_cnt);
861 jnz(label_bov_inner, T_NEAR);
862 }
863 add(reg_out_ptr, out_h_step);
864 dec(reg_bov);
865 jnz(label_bov, T_NEAR);
866 }
867 L(label_kwp_end);
868
869 { // Handle Right Overflow
870 Label label_rov, label_rov_skip;
871 // retrieve output pointer
872 mov(reg_out_ptr, reg_save_out_ptr);
873 // calculate the shift
874 imul(reg_tmp, reg_kwp, out_w_step);
875 // shift past the body
876 add(reg_out_ptr, reg_tmp);
877 // skip if no right overflow
878 test(reg_rov, reg_rov);
879 jz(label_rov_skip, T_NEAR);
880
881 L(label_rov); // handle left or right overflow
882 {
883 Label label_rov_inner;
884 mov(reg_aux_out_ptr, reg_out_ptr);
885 mov(reg_cnt, reg_kht);
886 L(label_rov_inner);
887 {
888 zero_it(reg_aux_out_ptr);
889 add(reg_aux_out_ptr, out_h_step);
890 dec(reg_cnt);
891 jnz(label_rov_inner, T_NEAR);
892 }
893 add(reg_out_ptr, out_w_step);
894 dec(reg_rov);
895 jnz(label_rov, T_NEAR);
896 }
897 L(label_rov_skip);
898 }
899
900 // For bf16, zero-pad an extra cacheline to avoid NaNs
901 // For int8, it is sufficient to zero-pad the weights only
902 if (is_bf16) {
903 // shift forward to align h index to end of needed buffer
904 imul(reg_tmp, reg_kht, out_h_step);
905 add(reg_out_ptr, reg_tmp);
906 // shift backward to align w index to end of needed buffer
907 sub(reg_out_ptr, out_w_step);
908 vmovdqu16(ptr[reg_out_ptr], zmm_zero);
909 }
910}
911
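// Entry point for the input copy kernel: dispatch to the reduced-lowering
// routine when 'is_relo' is set; otherwise loop over ic blocks (and kd for
// 3D) copying rows with explicit top/bottom zero padding.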
912void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {
913
914 // Special copy kernel for reduced lowering
915 if (jcp.is_relo) {
916 assert(jcp.nb_ic_int == 1);
917 preamble();
918 copy_row_reduced_lowering();
919 postamble();
920 return;
921 }
922
923 preamble();
924
925 const bool is_3d = jcp.ndims == 5;
926 mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
927 mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
928 if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
929 mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
930 mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
931 mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
932 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
933
934 vpxord(zmm_zero, zmm_zero, zmm_zero);
935
936 if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
937 int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
938 uint64_t mask = (UINT64_C(1) << tail_size) - 1;
939 mov(reg_tmp, mask);
940 kmovq(ktail_mask, reg_tmp);
941 }
942
943 for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
944 Xbyak::Label kd_label, no_kd_label;
945 Xbyak::Label kh_label, no_kh_label, icb_label;
946 Xbyak::Label kh_tover_label, kh_bover_label;
947 Xbyak::Label no_kh_tover_label, no_kh_bover_label;
948
949 mov(reg_aux_inp_ptr, reg_inp_ptr);
950 mov(reg_aux_out_ptr, reg_out_ptr);
951 if (is_3d) {
952 cmp(reg_kdp, 0);
953 jle(no_kd_label, T_NEAR);
954 mov(reg_kdc, reg_kdp);
955 L(kd_label);
956 push(reg_aux_inp_ptr);
957 push(reg_aux_out_ptr);
958 }
959 cmp(reg_khp, 0);
960 jle(no_kh_bover_label, T_NEAR); // nothing to do
961 mov(reg_khc, reg_khp);
962
963 cmp(reg_tover, 0);
964 jle(no_kh_tover_label, T_NEAR);
965
966 mov(reg_kh_over, reg_tover);
967 L(kh_tover_label);
968 {
969 // TODO: adjust step to improve zeroing efficiency for small ic
970 for (int iw = 0; iw < jcp.iwp; iw++)
971 vmovups(ptr[reg_aux_out_ptr
972 + jcp.typesize_in * iw * jcp.ic_block_int_np],
973 zmm_zero);
974 int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
975 add(reg_aux_out_ptr, out_h_offset);
976
977 dec(reg_kh_over);
978 jnz(kh_tover_label, T_NEAR);
979 }
980 sub(reg_khc, reg_tover);
981 L(no_kh_tover_label);
982
983 cmp(reg_khc, reg_bover);
984 jle(no_kh_label, T_NEAR);
985
986 L(kh_label);
987 {
988 copy_row(icb);
989 size_t inp_h_offset = !jcp.is_nspc
990 ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
991 : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
992 * jcp.ic_without_padding;
993 size_t out_h_offset
994 = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
995
996 add(reg_aux_inp_ptr, inp_h_offset);
997 add(reg_aux_out_ptr, out_h_offset);
998
999 dec(reg_khc);
1000 cmp(reg_khc, reg_bover);
1001 jg(kh_label, T_NEAR);
1002 }
1003 L(no_kh_label);
1004
1005 cmp(reg_khc, 0);
1006 jle(no_kh_bover_label, T_NEAR);
1007
1008 L(kh_bover_label);
1009 {
1010 // TODO: adjust step to improve zeroing efficiency for small ic
1011 for (int iw = 0; iw < jcp.iwp; iw++)
1012 vmovups(ptr[reg_aux_out_ptr
1013 + jcp.typesize_in * iw * jcp.ic_block_int_np],
1014 zmm_zero);
1015 int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
1016 add(reg_aux_out_ptr, out_h_offset);
1017
1018 dec(reg_khc);
1019 jnz(kh_bover_label, T_NEAR);
1020 }
1021 size_t out_d_offset = (size_t)jcp.typesize_in
1022 * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
1023 L(no_kh_bover_label);
1024 if (is_3d) {
1025 size_t inp_d_offset = !jcp.is_nspc
1026 ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
1027 * (jcp.dilate_d + 1)
1028 : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
1029 * jcp.ic_without_padding * (jcp.dilate_d + 1);
1030 pop(reg_aux_out_ptr);
1031 pop(reg_aux_inp_ptr);
1032 add(reg_aux_inp_ptr, inp_d_offset);
1033 add(reg_aux_out_ptr, out_d_offset);
1034 dec(reg_kdc);
1035 jnz(kd_label, T_NEAR);
1036 L(no_kd_label);
1037 }
1038 // End IC Loop
1039 size_t inp_cb_offset = !jcp.is_nspc
1040 ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
1041 * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
1042 : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
1043 size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;
1044
1045 add(reg_inp_ptr, inp_cb_offset);
1046 add(reg_out_ptr, out_cb_offset);
1047 }
1048
1049 postamble();
1050}
1051
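// The constructor wires up the post-ops injector (for eltwise/binary/sum
// post-ops) and the helper copy kernels: the input pbuffer copy and, on the
// reduced-lowering path, the weights wbuffer copy.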
1052jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
1053 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
1054 const memory_desc_t &dst_md)
1055 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
1056 , jcp(ajcp)
1057 , attr_(attr) {
1058 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
1059 using namespace binary_injector;
1060 const auto &rhs_addr_reg = bin_injector_helper_reg_1;
1061 const auto &rhs_helper_reg = bin_injector_helper_reg_2;
1062 const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
1063 static constexpr bool preserve_gpr = false;
1064 static constexpr bool preserve_vmm = false;
1065 const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
1066 static constexpr bool use_exact_tail_scalar_bcast = true;
1067
1068 const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
1069 31, rhs_addr_reg, rhs_helper_reg, rhs_addr_cache_reg,
1070 preserve_gpr, preserve_vmm,
1071 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
1072 memory_desc_wrapper(dst_md), tail_size, ktail_mask,
1073 use_exact_tail_scalar_bcast};
1074 const binary_injector::static_params_t static_params {
1075 this->param1, rhs_arg_static_params};
1076
1077 postops_injector_ = utils::make_unique<
1078 injector::jit_uni_postops_injector_t<avx512_core>>(
1079 this, jcp.post_ops, static_params);
1080 }
1081 copy_to_pbuffer_
1082 = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp);
1083 if (jcp.is_relo)
1084 copy_to_wbuffer_
1085 = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>(
1086 jcp);
1087}
1088
1089status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() {
1090 CHECK(jit_generator::create_kernel());
1091 CHECK(copy_to_pbuffer_->create_kernel());
1092 if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel());
1093 if (jcp.req_zero_point_buffer) {
1094 zp_pbuff_kernel_
1095 = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>(
1096 jcp);
1097 if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory;
1098 CHECK(zp_pbuff_kernel_->create_kernel());
1099 }
1100 return status::success;
1101}
1102
1103// Tile register decomposition
1104// { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
1105int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor(
1106 int h, int i, bool is_h_tail) const {
1107 const int C_BASE = 0;
1108 const int C_LAST = 4;
1109 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
1110 MAYBE_UNUSED(C_LAST);
1111 const int tile = C_BASE
1112 + (jcp.nb_oh_blocking > 1
1113 ? h * jcp.nb_oh_blocking + i
1114 : (int)is_h_tail * jcp.nb_oc_blocking + i);
1115 assert(C_BASE <= tile && tile < C_LAST);
1116 return tile;
1117}
1118int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor(
1119 int h, bool is_h_tail) const {
1120 const int I_BASE = 4;
1121 const int I_LAST = 6;
1122 assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
1123 MAYBE_UNUSED(I_LAST);
1124 const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail);
1125 assert(I_BASE <= tile && tile < I_LAST);
1126 return tile;
1127}
1128int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const {
1129 const int W_BASE = 6;
1130 const int W_LAST = 8;
1131 assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
1132 MAYBE_UNUSED(W_LAST);
1133 const int tile = W_BASE + i;
1134 assert(W_BASE <= tile && tile < W_LAST);
1135 return tile;
1136}
1137
1138// Shifts and offsets
1139size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const {
1140 return (size_t)jcp.kd * get_inp_d_step();
1141}
1142size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const {
1143 return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw
1144 * jcp.ic_block_int_np * jcp.oc_block;
1145}
1146size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const {
1147 return (size_t)jcp.typesize_in
1148 * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
1149}
1150size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const {
1151 return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np
1152 * (jcp.dilate_h + 1);
1153}
1154size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const {
1155 return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np
1156 * jcp.oc_block;
1157}
1158size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const {
1159 return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np
1160 * jcp.oc_block;
1161}
1162size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset(
1163 int ohb, int ocb, size_t typesize) const {
1164 size_t el_offset = jcp.is_nspc
1165 ? (size_t)ocb * jcp.oc_block
1166 + (size_t)ohb * jcp.ow * jcp.ngroups
1167 * jcp.oc_without_padding
1168 : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block
1169 + (size_t)ohb * jcp.ow * jcp.oc_block;
1170 return (size_t)typesize * el_offset;
1171}
1172size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset(
1173 int ohb, int ocb, int j, size_t typesize) const {
1174 size_t offset_w = jcp.is_nspc
1175 ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding
1176 : (size_t)typesize * j * jcp.oc_block;
1177 return get_out_ocb_offset(ohb, ocb, typesize) + offset_w;
1178}
1179size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift(
1180 int width, size_t typesize) const {
1181 return jcp.is_nspc
1182 ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding
1183 : (size_t)typesize * width * jcp.oc_block;
1184}
1185size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset(
1186 int ohb, int ocb) const {
1187 size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block
1188 + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width
1189 * jcp.oc_block;
1190 return jcp.typesize_acc * el_offset;
1191}
1192size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset(
1193 int ohb, int ocb, int j) const {
1194 return get_wsp_ocb_offset(ohb, ocb)
1195 + (size_t)jcp.typesize_acc * j * jcp.oc_block;
1196}
1197size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const {
1198 return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width
1199 * jcp.oc_block * jcp.nb_oc_blocking;
1200}
1201size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const {
1202 size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block;
1203 size_t raw_oc_subblock_step
1204 = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
1205 size_t oc_subblock_step = jcp.is_relo
1206 ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block)
1207 : raw_oc_subblock_step;
1208 el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step;
1209 return jcp.typesize_in * el_offset;
1210}
1211size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const {
1212 size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh
1213 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w)
1214 * jcp.ic_block_int_np;
1215 return (size_t)jcp.typesize_in * jcp.tile_width * w_step;
1216}
1217size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
1218 if (jcp.is_relo)
1219 return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
1220 // calculate offset by height dimension
1221 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
1222 const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
1223 size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
1224 * jcp.ic_block_int_np;
1225
1226 // add offset by width dimension
1227 if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
1228 el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
1229 } else if (jcp.dilate_w > 0) {
1230 el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
1231 } else {
1232 // dilate_w == 0 && stride_w > 1
        // there are min(jcp.kw, jcp.stride_w) contiguous sets of input data
        // (one for each stride idx); they are placed one after another
        // without additional padding
1236
1237 // calculate set idx for current kw value
1238 int set_idx = kw % jcp.stride_w;
1239 // calculate shift within set for current kw value
1240 int set_shift = kw / jcp.stride_w;
1241
1242 // calculate the beginning of the current set along width, each set
1243 // with index set_i contains number of elements along width equal to
1244 // jcp.ow - 1 + jcp.kw / jcp.stride_w
1245 // + (set_i < jcp.kw % jcp.stride_w)
1246 size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
1247 + nstl::min(set_idx, jcp.kw % jcp.stride_w);
1248 el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
1249 }
1250 return jcp.typesize_in * el_offset;
1251}
1252
1253size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset(
1254 int ocb, int zp_h, int zp_w) const {
1255 const size_t ocb_offset = (size_t)ocb * jcp.oc_block;
1256 const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups
1257 * jcp.oc_without_padding;
1258 return (ocb_offset + sp_offset) * sizeof(int32_t);
1259}
1260
1261int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset(
1262 int index, int mid, int s_pad_output, int e_pad_output) {
1263 using namespace nstl;
1264 const int mid_end = e_pad_output - 1;
1265 int zp_mid = min(mid, max(0, index - mid_end));
1266 int zp_pad_offset
1267 = accum_with_upper_bound(index, s_pad_output, e_pad_output);
1268 return zp_pad_offset + zp_mid;
1269}
1270
1271// Code generation
1272void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) {
1273 for (int h = 0; h < jcp.nb_oh_blocking; h++)
1274 for (int i = 0; i < jcp.nb_oc_blocking; i++)
1275 tilezero(Tmm(get_out_tensor(h, i, tail)));
1276}
1277
1278void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters(
1279 bool start_with_last_tile_block) {
1280 prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
1281 ? jcp.tile_tail
1282 : jcp.tile_width;
1283
1284 row_count_ = 0;
1285 is_store_done_ = false;
1286 is_buffer_empty_ = true;
1287}
1288
1289size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block(
1290 const int block_size, const int pad_output) {
1291 return (size_t)(pad_output >= block_size ? block_size : 0)
1292 + (pad_output % block_size);
1293}
1294
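// Reduce a spatial dimension to the number of entries that need distinct
// zero-point compensation values: a limited number of blocks for the start
// padding, at most one full block for the unpadded middle, and a limited
// number for the end padding.
// Illustrative (hypothetical) values: dim_size = 16, block_size = 4,
// s_pad_output = 2, e_pad_output = 2 reduces to 12.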
1295size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
1296 const int dim_size, const int block_size, const int s_pad_output,
1297 const int e_pad_output) {
1298 using namespace nstl;
1299
1300 // start padding (s_pad)
1301 int s_pad_limit = reduce_to_block(block_size, s_pad_output);
1302 int s_pad_area_blk = rnd_up(s_pad_limit, block_size);
1303
1304 // middle (no padding)
1305 int no_pad_area = max(
1306 0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
1307 int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);
1308
1309 // end padding (e_pad)
1310 int no_pad_area_shift = no_pad_area % block_size;
1311 int e_pad_area_overlap
1312 = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
1313 // middle and end padding shift
1314 int e_pad_shift_limit
1315 = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
1316 int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
1317 // full end padding block
1318 int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);
1319
1320 // calculate reduced size of s_pad, middle and e_pad blocks.
1321 return min((size_t)dim_size,
1322 (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
1323 + e_pad_limit);
1324}
1325
1326Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask(
1327 const Ymm &ymm_in, bool mask_flag, bool store) {
1328 return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
1329 : ymm_in;
1330}
1331
1332Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask(
1333 const Zmm &zmm_in, bool mask_flag, bool store) {
1334 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
1335 : zmm_in;
1336}
1337
1338void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
1339 const Zmm &zmm_in, const Operand &op, bool mask_flag) {
1340 const Zmm zmm = zmm_mask(zmm_in, mask_flag);
1341 switch (type_in) {
1342 case data_type::f32:
1343 case data_type::s32: vmovups(zmm, op); break;
1344 case data_type::s8: vpmovsxbd(zmm, op); break;
1345 case data_type::u8: vpmovzxbd(zmm, op); break;
1346 case data_type::bf16:
1347 vpmovzxwd(zmm, op);
1348 vpslld(zmm, zmm, 0x10);
1349 break;
1350 default: assert(!"unsupported data type");
1351 }
1352 if (!utils::one_of(type_in, data_type::f32, data_type::bf16))
1353 vcvtdq2ps(zmm_in, zmm_in);
1354}
1355
1356void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
1357 const float *p_sum_scale, const int32_t *p_sum_zp,
1358 const Xbyak::Address &addr, const bool mask_flag) {
1359 if (p_sum_scale) {
1360 const float p_sum_scale_val = *p_sum_scale;
1361 const int32_t p_sum_zp_val = *p_sum_zp;
1362 const auto sum_injector = [&, p_sum_scale_val, p_sum_zp_val,
1363 mask_flag]() {
1364 cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
1365 if (p_sum_zp_val != 0) {
1366 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
1367 vsubps(zmm_prev_dst, zmm_sum_zp);
1368 }
1369 if (p_sum_scale_val == 1.f)
1370 vaddps(zmm_out, zmm_prev_dst);
1371 else
1372 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
1373 };
1374 postops_injector_->set_lambda_injector(
1375 primitive_kind::sum, sum_injector);
1376 }
1377}
1378
1379void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
1380 const float *p_sum_scale, const int32_t *p_sum_zp,
1381 const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
1382 if (jcp.with_eltwise || jcp.with_binary
1383 || (jcp.with_sum && p_sum_scale != nullptr)) {
1384 apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);
1385
1386 const auto vmm_idx = zmm_out.getIdx();
1387 if (jcp.with_binary) {
1388 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1389 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr);
1390 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
1391 if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1392
1393 postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
1394 } else {
1395 postops_injector_->compute_vector(vmm_idx);
1396 }
1397 }
1398}
1399
1400void jit_avx512_core_amx_fwd_kernel_t::store_output_ymm_bf16(
1401 const int idx, const Xbyak::Address &addr, const bool mask_flag) {
1402 Ymm ymm_out = Ymm(idx);
1403 vcvtneps2bf16(ymm_out, Zmm(idx));
1404 vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
1405}
1406
1407void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
1408 const Zmm &zmm_out, int ocb, int h, int w) {
1409 const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
1410 && ocb == (jcp.nb_oc_blocking - 1);
1411
1412 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1413 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1414
1415 const auto &p = attr_.post_ops_;
1416
1417 const int sum_idx = p.find(primitive_kind::sum);
1418 if (sum_idx != -1) {
1419 if (jcp.dst_dt == data_type::bf16) {
1420 vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
1421 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
1422 vaddps(zmm_out, zmm_prev_dst);
1423 } else {
1424 vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
1425 vaddps(zmm_out, zmm_prev_dst);
1426 }
1427 }
1428 if (jcp.with_bias) {
1429 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
1430 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1431 if (jcp.bia_dt == data_type::bf16) {
1432 vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
1433 vpslld(zmm_bias, zmm_bias, 16);
1434 vaddps(zmm_out, zmm_bias);
1435 } else
1436 vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
1437 }
1438
1439 static constexpr auto skip_sum_injection = nullptr;
1440 apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off,
1441 mask_flag);
1442
1443 if (jcp.dst_dt == data_type::bf16) {
1444 store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
1445 } else {
1446 vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
1447 }
1448}
1449
1450void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
1451 const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
1452 const int zp_h, const int zp_w) {
1453 const int nb_oc_block = jcp.nb_oc_blocking;
1454 const int oc_block = jcp.oc_block;
1455 const bool mask_flag = true && jcp.oc_without_padding != jcp.oc
1456 && ocb == (nb_oc_block - 1);
1457
1458 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1459 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1460
1461 const auto &p = attr_.post_ops_;
1462 const int sum_idx = p.find(primitive_kind::sum);
1463 const float *p_sum_scale = nullptr;
1464 const int32_t *p_sum_zp = nullptr;
1465 if (sum_idx != -1) {
1466 const auto &p_entry = p.entry_[sum_idx];
1467 p_sum_scale = &p_entry.sum.scale;
1468 p_sum_zp = &p_entry.sum.zero_point;
1469 }
1470
1471 if (p_sum_scale) {
1472 if (*p_sum_scale != 1.f)
1473 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
1474 if (*p_sum_zp != 0)
1475 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
1476 }
1477
1478 int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
1479 if (jcp.with_bias) {
1480 int bias_offset = jcp.typesize_bia * ocb * oc_block;
1481 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1482 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
1483 }
1484 if (compute_zp) {
1485 assert(jcp.req_zero_point_buffer);
1486 // add zero-point padding compensation when accum data is S32
1487 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1488 vmovups(m_zmm_zp,
1489 EVEX_compress_addr(reg_zero_point_pbuff,
1490 get_zp_comp_offset(ocb, zp_h, zp_w)));
1491 const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
1492 vpaddd(m_zmm_out, zmm_out, zmm_zp);
1493 }
1494 if (jcp.src_zero_point) {
1495 // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
1496 int zp_offset = sizeof(int32_t) * ocb * oc_block;
1497 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1498 vpmulld(m_zmm_zp, zmm_src_zp,
1499 EVEX_compress_addr(reg_zp_compensation, zp_offset));
1500 vpaddd(zmm_out, zmm_out, zmm_zp);
1501 }
1502
    /* convert accumulator to f32, then apply per-channel scale and bias */
1504 vcvtdq2ps(zmm_out, zmm_out);
1505 const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
1506 vmulps(zmm_out_msk, zmm_out,
1507 EVEX_compress_addr(reg_ptr_scales, scale_offset));
1508 if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
1509
1510 apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);
1511
1512 if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }
1513 if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
1514
1515 // Properly saturate the accumulators for integer datatypes
1516 if (one_of(jcp.dst_dt, u8, s8, s32)) {
1517 init_saturate_f32(
1518 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
1519 saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
1520 vcvtps2dq(zmm_out, zmm_out);
1521 }
1522
1523 const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
1524
1525 switch (jcp.dst_dt) {
1526 case data_type::f32:
1527 case data_type::s32: vmovups(addr, zmm_out_store); break;
1528 case data_type::bf16:
1529 store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
1530 break;
1531 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
1532 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
1533 default: assert(!"unknown dst_dt");
1534 }
1535}
1536
1537void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out,
1538 int ocb, int h, int w, const bool compute_zp, const int zp_h,
1539 const int zp_w) {
    /*
    Output:
              jcp.is_nspc                 !jcp.is_nspc
              ---------------------       ---------------------
        INT8: [N][H][W][NBOC][16OC]
        BF16: [N][H][W][NBOC][16OC]  or   [N][NBOC][H][W][16OC]
    */
1547 if (jcp.src_dt == data_type::bf16) {
1548 store_output_vector_bf16(zmm_out, ocb, h, w);
1549 } else {
1550 store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w);
1551 }
1552}
1553
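// Spill each output tile to the workspace with tilestored, then (when
// do_store is set) read every row back as a zmm, apply the post-processing
// and write it to the destination, adding zero-point padding compensation
// for rows/columns that overlap a padded region.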
1554void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
1555 bool do_store, const bool handle_h_blk, const int t_pad_output,
1556 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1557 const bool is_last_oh_block, const bool zp_3d_pad) {
1558 auto store_output_block = [=](int width, int tail, bool do_store,
1559 bool is_last_h = false) {
1560 // Calculate the number of oh blocks; it may differ on last call
1561 const int last_h_blks
1562 = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
1563 const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
1564 : jcp.nb_oh_blocking;
1565 // Calculate the number of oh rows per tile; it may differ on last call
1566 const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
1567 ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
1568 : h_blks * jcp.oh_per_tile;
1569 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
1570 const int owp = gen_kw + jcp.ow - 1;
1571
1572 if (jcp.dst_scale) {
1573 mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
1574 vmovups(zmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0));
1575 }
1576 if (jcp.src_zero_point) {
1577 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1578 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1579 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1580 }
1581 if (jcp.dst_zero_point) {
1582 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1583 vcvtdq2ps(zmm_dst_zp,
1584 EVEX_compress_addr(reg_dst_zero_point, 0, true));
1585 }
1586
1587 for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1588 for (int ohb = 0; ohb < h_blks; ohb++) {
1589 /* Formats: Workspace: [NBOC][W][16OC] */
1590 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
1591 + get_wsp_ocb_offset(ohb, ocb)],
1592 Tmm(get_out_tensor(ohb, ocb, tail)));
1593 is_buffer_empty_ = false;
1594 is_store_done_ = false;
1595
1596 // preserve registers used by binary post_ops injector
1597 const injector_utils::conditional_register_preserve_guard_t
1598 cond_register_guard(jcp.with_binary, this,
1599 {bin_injector_helper_reg_1,
1600 bin_injector_helper_reg_2,
1601 bin_injector_helper_reg_3});
1602
1603 for (int tw = 0; tw < width && do_store; tw++) {
1604 // height
1605 const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
1606 const bool zp_h_pad
1607 = oh_index < t_pad_output || oh_index >= b_pad_output;
1608 const int zp_h = get_zp_index_offset(
1609 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1610 // width
1611 const int ow_index = tw % owp;
1612 const bool zp_w_pad
1613 = ow_index < l_pad_output || ow_index >= r_pad_output;
1614 const int zp_w = get_zp_index_offset(
1615 ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1616
1617 const bool compute_zp = jcp.req_zero_point_buffer
1618 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1619
1620 assert(IMPLICATION(jcp.oh_per_tile == 1,
1621 ohb == oh_index && tw == ow_index));
1622 if (oh_index < h_tail && ow_index < jcp.ow) {
1623 Zmm zmm_r = zmm_out(tw);
1624 vmovups(zmm_r,
1625 ptr[reg_wsp_ptr
1626 + get_wsp_row_offset(ohb, ocb, tw)]);
1627 store_output_vector(zmm_r, ocb, oh_index, ow_index,
1628 compute_zp, zp_h, zp_w);
1629 }
1630 }
1631 }
1632 };
1633
1634 // adjustment in case interleave store is turned off
1635 do_store = do_store || jcp.per_one_pstore == 0;
1636 if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
1637 if (!handle_h_blk) {
1638 store_output_block(width, tail, do_store, is_last_oh_block);
1639 } else {
1640 if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
1641 store_output_block(width, tail, do_store);
1642 } else {
1643 Label label_oh_oc_store, label_done;
1644 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
1645 cmp(reg_last_h, 0);
1646 jne(label_oh_oc_store, T_NEAR);
1647 store_output_block(width, tail, do_store, true); // last h
1648 jmp(label_done, T_NEAR);
1649 L(label_oh_oc_store);
1650 store_output_block(width, tail, do_store, false);
1651 L(label_done);
1652 }
1653 }
1654 if (do_store) {
1655 add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
1656 if (jcp.req_zero_point_buffer) {
1657 const size_t sp_shift
1658 = accum_with_upper_bound(width, l_pad_output, r_pad_output);
1659 add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
1660 }
1661 }
1662}
1663
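// Interleave the store of previously accumulated results with the current
// tile computations: each call flushes up to jcp.per_one_pstore rows from the
// workspace buffer to the destination. Once all rows of the buffered block
// (prv_width_ * nb_oc_blocking * nb_oh_blocking) have been written, the output
// (and, if required, zero-point buffer) pointers are advanced and the store is
// marked as done for this block.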
1664void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
1665 int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
1666 for (int c = 0;
1667 c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
1668 c++) {
1669 // row_count = ohb * OCB * TW + ocb * TW + tw
1670 int tw = row_count_ % prv_width_;
1671 int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
1672 int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;
1673
1674 // preserve registers used by binary post_ops injector
1675 const injector_utils::conditional_register_preserve_guard_t
1676 cond_register_guard(jcp.with_binary, this,
1677 {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
1678
1679 // height
1680 const int oh_index = ohb;
1681 const bool zp_h_pad
1682 = oh_index < t_pad_output || oh_index >= b_pad_output;
1683 const int zp_h = get_zp_index_offset(
1684 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1685 // width
1686 const int l_pad_output
1687 = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
1688 const int r_pad_output
1689 = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;
1690
1691 const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
1692 const int zp_w = get_zp_index_offset(
1693 tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1694
1695 const bool compute_zp = jcp.req_zero_point_buffer
1696 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1697
1698 Zmm zmm_r = zmm_out(tw);
1699 vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
1700 store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
1701 row_count_++;
1702
1703 if (row_count_
1704 == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
1705 add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
1706 if (jcp.req_zero_point_buffer) {
1707 const size_t sp_shift = accum_with_upper_bound(
1708 prv_width_, l_pad_output, r_pad_output);
1709 add(reg_zero_point_pbuff,
1710 get_out_shift(sp_shift, sizeof(int32_t)));
1711 if (!w_padding.empty()) w_padding.pop();
1712 }
1713 row_count_ = 0;
1714 is_store_done_ = true;
1715 prv_width_ = width;
1716 }
1717 }
1718}
1719
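// Main accumulation loop over one width block: load input and weight tiles,
// issue AMX tile multiply-accumulate (tdp*) instructions into the C-tiles
// while interleaving partial stores of the previous block, then call
// store_output(). The reduced-lowering (is_relo) path iterates over the
// flattened kh * kw * ic reduction instead of the explicit kd/kh/kw loops.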
1720void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width,
1721 bool do_store, const bool handle_h_blk, const int t_pad_output,
1722 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1723 const bool zp_3d_pad, const bool is_last_oh_block) {
1724 const bool tail = width == jcp.tile_tail;
1725
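    // Select the AMX tile multiply-accumulate instruction matching the
    // src/wei data-type combination (bf16 or signed/unsigned int8)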
1726 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1727 if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
1728 tdpbf16ps(x1, x2, x3);
1729 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
1730 tdpbuud(x1, x2, x3);
1731 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
1732 tdpbusd(x1, x2, x3);
1733 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
1734 tdpbsud(x1, x2, x3);
1735 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
1736 tdpbssd(x1, x2, x3);
1737 } else {
1738 assert(!"unsupported combination");
1739 }
1740 };
1741
1742 prepare_output(tail);
1743
1744    // prepare registers needed when 'interleave_store()' is executed
1745 if (jcp.dst_scale) {
1746 mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
1747 vmovups(zmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0));
1748 }
1749 if (jcp.src_zero_point) {
1750 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1751 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1752 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1753 }
1754 if (jcp.dst_zero_point) {
1755 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1756 vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
1757 }
1758
1759 // reduced lowering path
1760 if (jcp.is_relo) {
1761 const int nreduce = jcp.nreduce;
1762        const int stride = jcp.ic_block_int; // i.e. 64 for int8, 32 for bf16
1763
1764 push(reg_inp_ptr);
1765 push(reg_wei_ptr);
1766
1767 for (int ireduce = 0; ireduce < nreduce; ireduce += stride) {
1768 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1769 tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1770 ptr[reg_inp_ptr + get_inp_offset(ohb, 0)
1771 + reg_inp_stride]);
1772 }
1773 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1774 tileloadd(Tmm(get_wei_tensor(ocb)),
1775 ptr[reg_wei_ptr + get_wei_offset(ocb, 0)
1776 + reg_wei_stride]);
1777 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1778 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1779 Tmm(get_inp_tensor(ohb, tail)),
1780 Tmm(get_wei_tensor(ocb)));
1781 interleave_store(width, t_pad_output, b_pad_output);
1782 }
1783 }
1784 if (ireduce + stride < nreduce) {
1785 add(reg_inp_ptr, stride * jcp.typesize_in);
1786 add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in);
1787 }
1788 }
1789 pop(reg_wei_ptr);
1790 pop(reg_inp_ptr);
1791
1792 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1793 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block);
1794
1795 add(reg_inp_ptr, get_inp_shift());
1796
1797 return;
1798 }
1799
1800 auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) {
1801 return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step()
1802 + kh * get_wei_h_step() + get_wei_offset(ocb, kw);
1803 };
1804
1805 auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) {
1806 return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step()
1807 + kh * get_inp_h_step() + get_inp_offset(ohb, kw);
1808 };
1809
1810 auto safe_tileloadd
1811 = [=](const Tmm &t1, const Xbyak::Reg64 &reg_ptr, size_t offset,
1812 const Xbyak::Reg64 &reg_stride) {
1813 if (offset <= INT32_MAX) {
1814 tileloadd(t1, ptr[reg_ptr + offset + reg_stride]);
1815 } else {
1816 safe_add(reg_ptr, offset, reg_tmp);
1817 tileloadd(t1, ptr[reg_ptr + reg_stride]);
1818 safe_sub(reg_ptr, offset, reg_tmp);
1819 }
1820 };
1821
1822 // normal and k-remainders path
1823 const bool check_kd_padding
1824 = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
1825 for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
1826 Label kd_skip_compute;
1827 if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1828
1829 for (int kd = 0; kd < jcp.kd; kd++) {
1830 if (check_kd_padding) {
1831 dec(reg_kd);
1832 jl(kd_skip_compute, T_NEAR);
1833 push(reg_kd);
1834 }
1835 for (int kh = 0; kh < jcp.kh; kh++) {
1836 for (int set_idx = 0; set_idx < jcp.n_stride_sets;
1837                        set_idx++) { // used to optimize input memory reuse in the L1 cache
1838 for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) {
1839 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1840 const size_t inp_off
1841 = inp_offset(icb, ohb, kd, kh, kw);
1842 safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1843 reg_inp_ptr, inp_off, reg_inp_stride);
1844 }
1845 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1846 const size_t wei_off
1847 = wei_offset(icb, ocb, kd, kh, kw);
1848 safe_tileloadd(Tmm(get_wei_tensor(ocb)),
1849 reg_wei_ptr, wei_off, reg_wei_stride);
1850 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1851 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1852 Tmm(get_inp_tensor(ohb, tail)),
1853 Tmm(get_wei_tensor(ocb)));
1854 interleave_store(width, t_pad_output,
1855 b_pad_output, zp_3d_pad);
1856 }
1857 }
1858 }
1859 }
1860 }
1861 if (check_kd_padding) pop(reg_kd);
1862 }
1863 L(kd_skip_compute);
1864 }
1865
1866 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1867 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block,
1868 zp_3d_pad);
1869
1870 add(reg_inp_ptr, get_inp_shift());
1871}
1872
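// If a zero-point padding buffer is required and the output has top/bottom
// padding, generate one compute_icb_loop() per oh block with its specific
// t_pad/b_pad values and dispatch among them at run time through a jump table
// indexed by the 'ohb' kernel argument. Otherwise a single compute_icb_loop()
// handles all oh blocks.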
1873void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
1874 bool do_store, const int l_pad_output, const int r_pad_output,
1875 const bool zp_3d_pad) {
1876 if (jcp.req_zero_point_buffer
1877 && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
1878 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
1879 const size_t height_limit = reduce_to_blocked_dims(
1880 jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
1881 const int ur_h = div_up(height_limit, oh_step_size);
1882 assert(6 >= ur_h);
1883
1884 // Use a jump-table to execute the corresponding block
1885 Label h_blk_label[6], h_blk_end_label, jmp_table_label;
1886 mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
1887 mov(reg_tmp, jmp_table_label);
1888 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1889 jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen
1890
1891 align(8);
1892 L(jmp_table_label);
1893 for (int u = 0; u < ur_h; ++u) {
1894 putL(h_blk_label[u]);
1895 }
1896
1897 // Save value of global variables for the next 'h_blk' iteration
1898 const int local_prv_width = prv_width_;
1899 const int local_row_count = row_count_;
1900 const bool local_is_store_done = is_store_done_;
1901 const bool local_is_buffer_empty = is_buffer_empty_;
1902
1903        // Unroll oh blocks with regard to t_pad_output and b_pad_output
1904 int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
1905 int cur_b_pad = height_limit
1906 - reduce_to_block(oh_step_size, jcp.b_pad_output);
1907 for (int u = 0; u < ur_h; u++) {
1908 bool last = u == ur_h - 1;
1909 L(h_blk_label[u]);
1910
1911 // restore to previous 'h_blk' state of variables
1912 prv_width_ = local_prv_width;
1913 row_count_ = local_row_count;
1914 is_store_done_ = local_is_store_done;
1915 is_buffer_empty_ = local_is_buffer_empty;
1916 compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
1917 l_pad_output, r_pad_output, zp_3d_pad, last);
1918 cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
1919 cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
1920 if (!last) jmp(h_blk_end_label, T_NEAR);
1921 }
1922 L(h_blk_end_label);
1923 } else {
1924 compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
1925 r_pad_output, zp_3d_pad);
1926 }
1927}
1928
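// For 3D convolution with front/back padding and a zero-point buffer, two
// variants of the icb loop are generated; at run time the 'kd_padding'
// argument selects the padded variant (zp_3d_pad = true) whenever the current
// kd range does not cover the full kernel depth.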
1929void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
1930 bool do_store, const int l_pad_output, const int r_pad_output) {
1931 if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
1932 Label compute_3d_zp_label, zp_d_end_label;
1933 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1934 cmp(reg_kd, jcp.kd);
1935 jne(compute_3d_zp_label, T_NEAR);
1936
1937 // Save value of global variables for next 'dispatch_icb_loop'
1938 const int local_prv_width = prv_width_;
1939 const int local_row_count = row_count_;
1940 const bool local_is_store_done = is_store_done_;
1941 const bool local_is_buffer_empty = is_buffer_empty_;
1942 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1943
1944 jmp(zp_d_end_label, T_NEAR);
1945 L(compute_3d_zp_label);
1946
1947 prv_width_ = local_prv_width;
1948 row_count_ = local_row_count;
1949 is_store_done_ = local_is_store_done;
1950 is_buffer_empty_ = local_is_buffer_empty;
1951 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);
1952
1953 L(zp_d_end_label);
1954 } else
1955 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1956}
1957
1958void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() {
1959 auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks,
1960 const int l_pad_output,
1961 const int r_pad_output) {
1962 int cur_l_pad_output = l_pad_output;
1963 int cur_r_pad_output = r_pad_output;
1964 int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail
1965 : jcp.tile_width;
1966 init_runtime_counters(last_owb && num_tile_blocks == 1);
1967 for (int owb = 0; owb < num_tile_blocks - 1; owb++) {
1968 dispatch_zp_3d_compute(
1969 jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output);
1970 cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width);
1971 cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width);
1972 }
1973 dispatch_zp_3d_compute(
1974 gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output);
1975 };
1976
1977 assert(jcp.nb_ow > 0);
1978 if (jcp.nb_ow == 1) {
1979 const int ow_r_pad_start
1980 = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output);
1981 compute_ow_loop_body(
1982 true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start);
1983 } else if (jcp.req_zero_point_buffer
1984 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) {
1985
1986 const size_t zp_addr_shift
1987 = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t);
1988 const int ow_step_size = jcp.ow_block;
1989 const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width);
1990 const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0
1991 ? ow_blocks_per_call
1992 : jcp.ow_blocks % ow_blocks_per_call;
1993 const int width_limit = reduce_to_blocked_dims(
1994 jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output);
1995 const int ur_w = div_up(width_limit, ow_step_size);
1996 assert(6 >= ur_w);
1997 // Use a jump-table to execute the corresponding block
1998 Label w_blk_label[6], w_blk_end_label, jmp_table_label;
1999 mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]);
2000 mov(reg_tmp, jmp_table_label);
2001 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
2002 jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen
2003
2004 align(8);
2005 L(jmp_table_label);
2006 for (int u = 0; u < ur_w; ++u) {
2007 putL(w_blk_label[u]);
2008 }
2009
2010        // Unroll ow_block with regard to l_pad_output and r_pad_output
2011 int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output);
2012 int cur_r_pad
2013 = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output);
2014 int zp_offset = 0;
2015 for (int u = 0; u < ur_w; u++) {
2016 const bool last = u == ur_w - 1;
2017 L(w_blk_label[u]);
2018 if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift);
2019 compute_ow_loop_body(last,
2020 last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad,
2021 cur_r_pad);
2022 zp_offset += accum_with_upper_bound(
2023 ow_step_size, cur_l_pad, cur_r_pad);
2024 cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size);
2025 cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size);
2026 if (!last) jmp(w_blk_end_label, T_NEAR);
2027 }
2028 L(w_blk_end_label);
2029
2030 } else {
2031 assert(jcp.oh_per_tile == 1);
2032 Label label_done;
2033 int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width);
2034 int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call;
2035 if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0)
2036 last_owb_tile_blocks = ow_blocks_per_call;
2037 if (last_owb_tile_blocks > 0) {
2038 Label label_not_last_owb;
2039 mov(reg_tmp, ptr[param1 + GET_OFF(owb)]);
2040 cmp(reg_tmp, jcp.nb_ow - 1);
2041 jne(label_not_last_owb, T_NEAR);
2042
2043 compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow);
2044
2045 jmp(label_done, T_NEAR);
2046
2047 L(label_not_last_owb);
2048 }
2049 compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow);
2050
2051 L(label_done);
2052 }
2053}
2054
2055void jit_avx512_core_amx_fwd_kernel_t::generate() {
2056 preamble();
2057
2058 mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
2059 mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
2060 mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
2061 mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
2062 if (jcp.req_zero_point_buffer)
2063 mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
2064
2065 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
2066 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
2067
2068 const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
2069 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
2070 const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
2071 const int wei_stride = jcp.oc_block * jcp.typesize_acc;
2072 mov(reg_inp_stride, inp_stride);
2073 mov(reg_wei_stride, wei_stride);
2074
2075 if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
2076        // Use the full 16-lane mask (0xffff) by default for all output data
2077        // and post-ops loads / stores with block index
2078        // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
2079        // TODO: use masked loads / stores for the last occ only
2080 int current_block_size = jcp.oc_block;
2081 int mask = (1 << current_block_size) - 1;
2082 Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
2083 mov(regw_tmp, mask);
2084 kmovw(ktail_mask, regw_tmp);
2085 Xbyak::Label mask_is_set;
2086 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
2087 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
2088 jne(mask_is_set, T_NEAR);
2089 // Reset the mask
2090 current_block_size = jcp.oc_without_padding % jcp.oc_block;
2091 mask = (1 << current_block_size) - 1;
2092 mov(regw_tmp, mask);
2093 kmovw(ktail_mask, regw_tmp);
2094
2095 L(mask_is_set);
2096 }
2097 compute_ow_loop();
2098
2099 postamble();
2100
2101 if (jcp.with_eltwise) postops_injector_->prepare_table();
2102}
2103
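// Fill the 64-byte palette (tile configuration) buffer: weight tiles hold
// b_row rows of (oc_block * vnni_width) elements, input tiles hold tile_width
// rows of a_col elements, and accumulator tiles hold tile_width rows of 16
// int32 values. Separate input/accumulator entries are configured for the
// width tail when jcp.tile_tail != 0.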
2104void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) {
2105 const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4;
2106 // Input tile dimensions
2107 const int a_col = jcp.is_relo ? jcp.ic_block_int
2108 : jcp.ic_block_int_np * jcp.kw_per_tile;
2109 // Weights tile dimensions
2110 const int b_col = jcp.oc_block * vnni_width;
2111 const int b_row = a_col / vnni_width;
2112 // Accumulator tile dimensions
2113 const int c_col = 16;
2114
2115 for (size_t i = 0; i < 64; i++)
2116 tcfg_buff[i] = 0;
2117
2118 // Weights (W_BASE) Tensor Tiles
2119 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2120 tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
2121 b_row, b_col * jcp.typesize_in);
2122
2123 // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
2124 for (int h = 0; h < jcp.nb_oh_blocking; h++) {
2125 tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
2126 jcp.tile_width, a_col * jcp.typesize_in);
2127 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2128 tc_configure_tile((palette_config_t *)tcfg_buff,
2129 get_out_tensor(h, i), jcp.tile_width,
2130 c_col * jcp.typesize_acc);
2131 }
2132 if (jcp.tile_tail != 0) {
2133 assert(jcp.nb_oh_blocking == 1);
2134 assert(jcp.oh_per_tile == 1);
2135 assert(jcp.ow > jcp.tile_width);
2136 tc_configure_tile((palette_config_t *)tcfg_buff,
2137 get_inp_tensor(0, true), jcp.tile_tail,
2138 a_col * jcp.typesize_in);
2139 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2140 tc_configure_tile((palette_config_t *)tcfg_buff,
2141 get_out_tensor(0, i, true), jcp.tile_tail,
2142 c_col * jcp.typesize_acc);
2143 }
2144
2145 ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette();
2146}
2147
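// Pre-compute the output-height boundaries (h_blk_limits) between regions that
// need different zero-point padding handling: the full top-pad block, the
// top-pad/no-pad overlap, the middle unpadded area, the no-pad/bottom-pad
// overlap and the full bottom-pad block. All entries default to jcp.oh when no
// zero-point buffer is required.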
2148void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) {
2149
2150 constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]);
2151 // set default values
2152 for (int i = 0; i < size; i++)
2153 jcp.h_blk_limits[i] = jcp.oh;
2154
2155 const bool calculate_oh_limits
2156 = jcp.t_pad_output > 0 || jcp.b_pad_output > 0;
2157 if (jcp.req_zero_point_buffer && calculate_oh_limits) {
2158
2159 int limit_idx = 0;
2160 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2161
2162 // full t_pad output block
2163 const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size);
2164 if (jcp.t_pad_output >= oh_step_size) {
2165 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end;
2166 }
2167 // t_pad output overlap with no padding
2168 const int t_pad_shift = jcp.t_pad_output % oh_step_size;
2169 if (t_pad_shift != 0) {
2170 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift;
2171 }
2172 const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size);
2173 const int oh_blk_tail = jcp.oh % oh_step_size;
2174 const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail);
2175 const int b_pad_start
2176 = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output);
2177 const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size);
2178 // middle block without padding
2179 const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk);
2180 if (mid_blk >= oh_step_size) {
2181 jcp.h_blk_limits[limit_idx++] = b_pad_blk_start;
2182 }
2183 // no padding with b_pad overlap
2184 const int b_pad_shift = b_pad_no_tail % oh_step_size;
2185 if (b_pad_shift != 0) {
2186 jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size);
2187 }
2188 // full b_pad output block
2189 if (b_pad_no_tail >= oh_step_size) {
2190 jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail;
2191 }
2192 // b_pad tail block does not require a limit
2193 }
2194}
2195
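// Pre-compute how many ow blocks fall into the left-pad, unpadded and
// right-pad regions (l_pad_blk / no_pad_w_blk / r_pad_blk); these are consumed
// by the driver code when accumulating into the zero-point padding buffer.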
2196void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) {
2197
2198 jcp.l_pad_blk = 0;
2199 jcp.no_pad_w_blk = 0;
2200 jcp.r_pad_blk = 0;
2201
2202 const bool calculate_ow_limits
2203 = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0);
2204 if (jcp.req_zero_point_buffer && calculate_ow_limits) {
2205 const int ow_step_size = jcp.ow_block;
2206
2207 // l_pad
2208 const int l_pad_limit
2209 = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0)
2210 + (jcp.l_pad_output % ow_step_size);
2211 const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size);
2212 jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size);
2213
2214 // middle (area without padding)
2215 const int no_pad_area
2216 = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output);
2217 jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0;
2218
2219 // r_pad
2220 const int no_pad_area_shift = no_pad_area % ow_step_size;
2221 const int r_pad_area_overlap
2222 = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift;
2223 const int r_pad_area
2224 = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap);
2225 const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0)
2226 + (r_pad_area % ow_step_size);
2227 jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0)
2228 + div_up(r_pad_limit, ow_step_size);
2229 }
2230}
2231
2232status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
2233 const convolution_desc_t &cd, memory_desc_t &src_md,
2234 memory_desc_t &weights_md, memory_desc_t &dst_md,
2235 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
2236 using namespace prop_kind;
2237
2238 const memory_desc_wrapper src_d(&src_md);
2239 const memory_desc_wrapper weights_d(&weights_md);
2240 const memory_desc_wrapper dst_d(&dst_md);
2241 const memory_desc_wrapper bias_d(&bias_md);
2242
2243 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2244 int ndims = src_d.ndims();
2245 bool is_1d = ndims == 3;
2246 bool is_3d = ndims == 5;
2247
2248 const bool is_bf16_convolution
2249 = everyone_is(true, src_d.data_type() == data_type::bf16,
2250 weights_d.data_type() == data_type::bf16,
2251 one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
2252 const bool is_int8_convolution = everyone_is(true,
2253 (src_d.data_type() == data_type::u8
2254 || src_d.data_type() == data_type::s8),
2255 weights_d.data_type() == data_type::s8,
2256 one_of(dst_d.data_type(), data_type::f32, data_type::s32,
2257 data_type::s8, data_type::u8, data_type::bf16));
2258
2259 bool supported = mayiuse(avx512_core_amx)
2260 && (is_bf16_convolution || is_int8_convolution);
2261 if (!supported) return status::unimplemented;
2262
2263 jcp = zero<decltype(jcp)>();
2264 jcp.isa = avx512_core_amx;
2265 jcp.ndims = ndims;
2266 jcp.prop_kind = cd.prop_kind;
2267 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2268
2269 jcp.mb = src_d.dims()[0];
2270 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
2271 jcp.oc_without_padding = jcp.oc;
2272 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2273 jcp.ic_without_padding = jcp.ic;
2274 jcp.id = is_3d ? src_d.dims()[2] : 1;
2275 jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
2276 jcp.iw = src_d.dims()[ndims - 1];
2277 jcp.od = is_3d ? dst_d.dims()[2] : 1;
2278 jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
2279 jcp.ow = dst_d.dims()[ndims - 1];
2280 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
2281 jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
2282 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2283 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
2284 jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
2285 jcp.l_pad = cd.padding[0][ndims - 3];
2286 jcp.stride_d = is_3d ? cd.strides[0] : 1;
2287 jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
2288 jcp.stride_w = cd.strides[ndims - 3];
2289 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
2290
2291 jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
2292 jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
2293 jcp.dilate_w = cd.dilates[ndims - 3];
2294
2295 const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
2296 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
2297 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
2298 jcp.back_pad = calculate_end_padding(
2299 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
2300 jcp.b_pad = calculate_end_padding(
2301 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
2302 jcp.r_pad = calculate_end_padding(
2303 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
2304 if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
2305 || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
2306 || jcp.back_pad >= gen_kd)
2307 return status::unimplemented;
2308
2309 const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits
2310 if (jcp.l_pad > max_pad || jcp.r_pad > max_pad)
2311 return status::unimplemented; // TODO: relax this restriction
2312
2313 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
2314 jcp.dst_dt = cd.dst_desc.data_type;
2315 jcp.src_dt = cd.src_desc.data_type;
2316 jcp.wei_dt = cd.weights_desc.data_type;
2317
2318    jcp.is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc);
2319
2320 if (jcp.is_depthwise)
2321 return status::unimplemented; // TODO: add support of DW convolution
2322
2323 // Dispatch small shapes to VNNI for better performance
2324 const auto is_real_3d = (jcp.ndims == 5
2325 && (jcp.id > 1 || jcp.od > 1 || jcp.kd > 1 || jcp.dilate_d > 0));
2326 const auto is_supported_small_ic = jcp.ic <= 4 && !is_real_3d;
2327 const auto is_small_shape = (jcp.od * jcp.oh * jcp.ow <= 4) && jcp.ic <= 512
2328 && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= static_cast<int32_t>(
2329 platform::get_per_core_cache_size(1));
2330 const auto is_3d_small_ic = is_real_3d && jcp.ic * jcp.oc <= 32
2331 && jcp.od >= 128 && jcp.oh >= 128 && jcp.ow >= 128;
2332 if ((is_small_shape || is_3d_small_ic) && !is_supported_small_ic)
2333 return status::unimplemented;
2334
2335 const auto zp = attr.zero_points_;
2336 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
2337 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
2338 jcp.zp_src_is_common = zp.common(
2339 DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
2340
2341 if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
2342 || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
2343 is_int8_convolution))
2344 return status::unimplemented;
2345
2346 // Calculate zero-point padding values outside of the main JIT-kernel
2347 // and store the results in an auxiliary buffer.
2348 jcp.req_zero_point_buffer = jcp.src_zero_point
2349 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0
2350 || jcp.f_pad > 0 || jcp.back_pad > 0);
2351
2352 format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c,
2353 format_tag::nChw16c, format_tag::nCdhw16c);
2354 format_tag_t dat_tag_nspc = utils::pick(
2355 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
2356 // To toggle the default data layout for BF16 between nChw16c and nhwc,
2357 // swap the following two variable definitions. Current choice: nhwc.
2358
2361 format_tag_t dat_tag_opt = dat_tag_nspc;
2362 format_tag_t dat_tag_alt
2363 = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
2364
2365 if (src_d.format_kind() == format_kind::any) {
2366 CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2367 jcp.src_tag = dat_tag_opt;
2368 } else
2369 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
2370
2371 if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
2372 return status::unimplemented;
2373
2374 jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
2375 assert(IMPLICATION(is_int8_convolution, jcp.is_nspc));
2376
2377 // TODO: remove all support for nChw16c from this implementation
2378 if (!jcp.is_nspc) return status::unimplemented;
2379
2380 if (dst_d.format_kind() == format_kind::any) {
2381 CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag));
2382 jcp.dst_tag = jcp.src_tag;
2383 } else
2384 jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag);
2385
2386 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2387
2388 if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
2389 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
2390
2391 jcp.nthr = nthreads;
2392
2393 jcp.ic_block = 16;
2394 jcp.oc_block = 16;
2395
2396 const auto ic_unrounded = jcp.ic;
2397 if (jcp.ngroups == 1) {
2398 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2399 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2400 }
2401 bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
2402 if (!args_ok) return status::unimplemented;
2403
2404 const int vnni_width = is_bf16_convolution ? 2 : 4;
2405 jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8
2406
2407 // fallback to non-amx impl when accumulation is too small
2408 const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw;
2409 const bool is_tiny_k = total_k < jcp.ic_block_int / 2;
2410 if (is_tiny_k) return status::unimplemented;
2411
2412 // small-ic parameters
2413 jcp.ic_block_int_np = jcp.is_nspc
2414 ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding)
2415 : jcp.ic_block_int;
2416 bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2417
2418 // reduced lowering
2419 const bool is_trivial_3d = everyone_is(1, jcp.id, jcp.od, jcp.kd);
2420 jcp.is_relo = is_trivial_3d
2421 && is_small_ic
2422 // no trivial cases
2423 && 1 < jcp.kh * jcp.kw
2424 // reduction dimension size (heuristic)
2425 && IMPLICATION(is_int8_convolution,
2426 12 * 3 * 3 < ic_unrounded * jcp.kh * jcp.kw)
2427 // required for use of VPERMB instruction in weights copy kernel
2428 && IMPLICATION(is_int8_convolution,
2429 cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
2430            // no dilation along the h- and w-directions
2431            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
2432            // no excessive stride along the h- and w-directions
2433            && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw;
2434
2435 // Dispatch specific small ic shapes to VNNI for better performance
2436 const auto is_2d_small_ic = jcp.ndims == 4
2437 && (ic_unrounded < 12
2438 || (ic_unrounded >= 12 && ic_unrounded < 16
2439 && jcp.ow * jcp.oh < 768 * 768));
2440 if (is_int8_convolution && is_2d_small_ic && !jcp.is_relo)
2441 return status::unimplemented;
2442
2443 jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np;
2444
2445 if (!jcp.is_relo) {
2446 jcp.ic_block_int_np = is_bf16_convolution
2447 ? jcp.ic_block_int
2448 : rnd_up(jcp.ic_block_int_np, vnni_width);
2449 is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2450 }
2451 // k-remainders
2452 jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0
2453 && jcp.stride_w <= jcp.kw // TODO: relax this restriction
2454 && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int
2455 ? jcp.kw
2456 : 1;
2457 jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile);
2458 jcp.n_stride_sets
2459 = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1;
2460 jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile;
2461
2462 if (attr.set_default_formats(&dst_md) != status::success)
2463 return status::unimplemented;
2464
2465 const auto &p = attr.post_ops_;
2466
2467 const int sum_ind = p.find(primitive_kind::sum);
2468 jcp.with_sum = sum_ind != -1;
2469 const int eltwise_ind = p.find(primitive_kind::eltwise);
2470 jcp.with_eltwise = eltwise_ind != -1;
2471 const int binary_ind = p.find(primitive_kind::binary);
2472 jcp.with_binary = binary_ind != -1;
2473 jcp.sum_dt = p.get_sum_dt(jcp.dst_dt);
2474
2475 jcp.post_ops = p;
2476
2477 using namespace injector;
2478 const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
2479 const bool sum_requires_scale_one = sum_at_pos_0_only;
2480 const bool sum_requires_zp_zero = sum_at_pos_0_only;
2481 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
2482 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
2483 sum_requires_zp_zero});
2484 if (!post_ops_ok_) return status::unimplemented;
2485
2486 auto set_or_check_wei_format = [&]() {
2487 using namespace format_tag;
2488 using namespace memory_extra_flags;
2489 format_tag_t wei_tag;
2490 wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o,
2491 gOwi16o, Owhi16o, gOwhi16o, Odwhi16o, gOdwhi16o)
2492 : is_bf16_convolution
2493 ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
2494 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i,
2495 OIdhw16i16o2i, gOIdhw16i16o2i)
2496 : is_small_ic ? pick(with_groups + 2 * (ndims - 3),
2497 OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i,
2498 OdhwI16o4i, gOdhwI16o4i)
2499 : pick(with_groups + 2 * (ndims - 3),
2500 OIw16i16o4i, gOIw16i16o4i,
2501 OIhw16i16o4i, gOIhw16i16o4i,
2502 OIdhw16i16o4i, gOIdhw16i16o4i);
2503
2504 memory_desc_t want_wei_md = weights_md;
2505 memory_desc_init_by_tag(want_wei_md, wei_tag);
2506
2507 if (jcp.src_zero_point) {
2508 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
2509 want_wei_md.extra.asymm_compensation_mask = (1 << 0)
2510 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
2511 }
2512 if (weights_md.format_kind == format_kind::any) {
2513 weights_md = want_wei_md;
2514 return true;
2515 }
2516 return weights_md == want_wei_md;
2517 };
2518
2519 if (!set_or_check_wei_format()) return status::unimplemented;
2520
2521 jcp.typesize_in = types::data_type_size(src_d.data_type());
2522 jcp.typesize_out = types::data_type_size(dst_d.data_type());
2523 jcp.typesize_bia
2524 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
2525 jcp.typesize_acc = sizeof(int32_t);
2526
2527 jcp.nb_ic = jcp.ic / jcp.ic_block;
2528 jcp.nb_oc = jcp.oc / jcp.oc_block;
2529 jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int);
2530
2531 jcp.nb_oc_blocking_thr_chunk = 1;
2532
2533 const int target_palette = amx::get_target_palette();
2534 jcp.max_tiles = amx::get_max_tiles(target_palette);
2535 jcp.full_tile_width = amx::get_max_rows(target_palette);
2536 if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
2537 return status::unimplemented;
2538
2539 // Pack n rows per tile, such that:
2540 // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width
2541 auto calculate_tile_width = [&](int n) {
2542 assert(n > 0);
2543 return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1);
2544 };
2545 const bool ok_to_pack_tile = !jcp.is_relo
2546 && (utils::everyone_is(1, jcp.kh, jcp.kw)
2547 || utils::everyone_is(1, jcp.stride_h, jcp.stride_w));
2548 const int max_oh_per_tile
2549 = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1);
2550 jcp.oh_per_tile = ok_to_pack_tile
2551 ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile))
2552 : 1;
2553 jcp.tile_width = nstl::min<int>(
2554 jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile));
2555 jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width);
2556
2557 // Prefer to use a single tile width when possible
2558    // (e.g. ow = 28 => 2 tiles of 14 vs. 1 of 16 and 1 of 12)
2559 if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0)
2560 jcp.tile_width = jcp.ow / jcp.ow_blocks;
2561 jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0;
2562
2563 jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
2564 jcp.nb_ic_blocking = 1;
2565 jcp.nb_oh_blocking
2566 = utils::everyone_is(true, jcp.tile_tail == 0,
2567 // requirement for interleave stores
2568 IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0),
2569 // requirement for small spatial
2570 utils::div_up(jcp.oh, jcp.oh_per_tile) > 1,
2571 // choose maximal pbuffer overlap for reduced lowering
2572 !jcp.is_relo)
2573 ? 2
2574 : 1;
2575
2576 // TODO: tune oh blocking
2577 const int oh_blk_size_param = jcp.is_relo ? 1 : 10;
2578 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2579 const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size);
2580 jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size);
2581    // Here ihp means the input buffer height including padding (i.e. the
2582    // number of input rows required to compute jcp.oh_blk_size output rows).
2583    // If an input row doesn't participate in the computation of any output
2584    // row, it isn't copied to the buffer at all (e.g. when jcp.stride_h > gen_kh).
2585 jcp.ihp = jcp.is_relo
2586 ? jcp.oh_blk_size
2587 : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh;
2588
2589 // TODO: tune ow blocking
2590 const int ow_blocks_per_call = jcp.is_relo ? 10 : 2;
2591 jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call);
2592 jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block);
2593 // iwp includes all width elements that are really used in calculation
2594 // including left and right zero padding
2595 const bool are_sets_interleaved
2596 = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
2597 jcp.iwp = are_sets_interleaved
2598 ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw
2599 : jcp.ow_block * jcp.kw;
2600
2601 // Number of ops per tile store
2602 int ops_tile_store = jcp.tile_width;
2603 // Number of ops per accumulation tile
2604    int available_ops = jcp.is_relo
2605 ? utils::div_up(jcp.nreduce, jcp.ic_block_int)
2606 : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile);
2607 // Number of vectors to store per tile operation
2608 // NOTE: set to zero to turn off interleave store (mostly for debugging)
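    // e.g. ops_tile_store = 16 and available_ops = 8 give per_one_pstore = 2,
    // i.e. up to two rows are flushed after each tdp* accumulation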
2609    jcp.per_one_pstore = utils::div_up(ops_tile_store, available_ops);
2610
2611 if (jcp.is_relo) {
2612 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh
2613 * jcp.ic_block_int_np
2614 // pbuffer pointer shifts each oh step for reduced-lowering
2615 + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np
2616                // extra cache line due to pbuffer writing a full Zmm
2617 + jcp.ic_block_int;
2618 } else {
2619 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd
2620 * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np
2621                        // extra cache line due to pbuffer writing a full Zmm
2622 + jcp.ic_block_int);
2623 }
2624 jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc
2625 * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024);
2626 jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking
2627 * jcp.full_tile_width * jcp.oc_block;
2628
2629 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
2630 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
2631 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
2632 const int wei_mask_per_oc = 1 << (int)with_groups;
2633 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
2634 jcp.dst_scale = !dst_scales.has_default_values();
2635
2636 // only common src & dst scales are supported
2637 // only common and per-oc-channel weight scales are supported
2638 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
2639 && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
2640 if (!scales_ok) return status::unimplemented;
2641
2642 // Note: currently unsupported, results in seg-fault
2643 const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w));
2644 if (!jcp.is_relo && (l_pad_output > jcp.ow_block))
2645 return status::unimplemented;
2646
2647 // Relevant to 'zero_point padding buffer' (pbuff) jit kernel
2648 if (jcp.req_zero_point_buffer) {
2649 auto calculate_output_padding_dims = [=](int o_dim, int s_pad,
2650 int e_pad,
2651 int &s_pad_output,
2652 int &e_pad_output,
2653 bool &o_mid, int &o_pad,
2654 int stride,
2655 bool req_mid_area) {
2656 s_pad_output = nstl::min(o_dim, div_up(s_pad, stride));
2657 e_pad_output = nstl::min(o_dim, div_up(e_pad, stride));
2658 o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area;
2659 o_pad = nstl::min(o_dim,
2660 nstl::max(1, s_pad_output + e_pad_output + (int)o_mid));
2661 };
2662
2663 const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0)
2664 && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
2665 || jcp.back_pad > 0);
2666 const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0)
2667 && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0
2668 || jcp.back_pad > 0);
2669 const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0)
2670 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0
2671 || jcp.t_pad > 0);
2672 calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad,
2673 jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad,
2674 jcp.stride_w, mid_w_area);
2675 calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad,
2676 jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad,
2677 jcp.stride_h, mid_h_area);
2678 calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad,
2679 jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad,
2680 jcp.stride_d, mid_d_area);
2681 jcp.zp_pbuff_size
2682 = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups;
2683
2684 // compute zero-point padding kernel outside of the main parallel
2685 // region when threads are more likely to parallelize work across mb
2686 // within the convolution compute block.
2687 jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d;
2688
2689 const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2);
2690 if (!params_ok) { return status::unimplemented; }
2691 }
2692
2693 // Set default parameters for driver code, but mostly required for
2694 // 'zero_point padding buffer' (pbuff) accumulation over output tensor
2695 set_oh_blk_limits(jcp);
2696 set_ow_blk_limits(jcp);
2697
2698 return status::success;
2699}
2700
2701status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
2702 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
2703 const primitive_attr_t &attr) {
2704
2705 size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
2706 scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
2707 if (jcp.is_relo) {
2708 scratchpad.book(
2709 key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in);
2710 }
2711
2712 size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
2713 scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
2714 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
2715 assert(jcp.ngroups == 1);
2716 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
2717 }
2718 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2719 if (jcp.req_zero_point_buffer) {
2720 const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr;
2721 scratchpad.book(key_conv_zero_point_pad,
2722 (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t));
2723 if (!jcp.zp_pbuff_outer_compute) {
2724 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
2725 scratchpad.book<bool>(key_conv_zero_point_flag,
2726 (size_t)jcp.nthr * oc_chunks * jcp.ngroups);
2727 }
2728 }
2729
2730 book_precomputed_scales(
2731 scratchpad, attr.scales_, jcp.ngroups * jcp.oc_without_padding);
2732
2733 // Keep scratchpad memory footprint under control
2734 const size_t L2_size_per_core = platform::get_per_core_cache_size(2);
2735 const size_t L3_size_per_core = platform::get_per_core_cache_size(3);
2736 const size_t max_scratchpad_size
2737 = jcp.nthr * (L2_size_per_core + L3_size_per_core);
2738 // TODO: tune this relationship as needed
2739 if (scratchpad.size() > max_scratchpad_size) return status::unimplemented;
2740 return status::success;
2741}
2742
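// Copy the diff_dst rows needed for one output block into the zero-padded
// (and, for stride > 1, stride-dilated) auxiliary buffer used by the
// backward-data kernel: zero rows are written for top/bottom overflow, zero
// columns for left/right overflow, and zero elements are inserted between
// copied columns/rows when stride_w/stride_h > 1.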
2743void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
2744 const bool is_masked) {
2745 assert(jcp.is_nspc && "no support for nChw16c in this copy kernel");
2746
2747 const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
2748 const int inp_w_step
2749 = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
2750 const int inp_h_step = jcp.ow * inp_w_step;
2751 const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
2752 const int out_h_step = jcp.owp * out_w_step;
2753
2754 auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
2755 // no mask as output is a padded buffer
2756 if (is_bf16)
2757 vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
2758 else
2759 vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
2760 };
2761
2762 auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
2763 int out_off) {
2764 Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
2765 Zmm zmm_stor = zmm_tmp; // no mask as output is padded buffer
2766 if (is_bf16) {
2767 vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2768 vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
2769 } else {
2770 vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2771 vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
2772 }
2773 };
2774
2775 { // Handle Top Overflow
2776 Label label_tov_loop, label_tov_skip;
2777 // number of zero-padded rows above src buffer to copy
2778 mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
2779 test(reg_tov, reg_tov);
2780 jz(label_tov_skip, T_NEAR);
2781 L(label_tov_loop);
2782 {
2783 for (int ow = 0; ow < jcp.owp; ow++) {
2784 const int offset = ow * out_w_step;
2785 zero_it(reg_ptr_aux_out, offset);
2786 }
2787 add(reg_ptr_aux_out, out_h_step);
2788 dec(reg_tov);
2789 jnz(label_tov_loop, T_NEAR);
2790 }
2791 L(label_tov_skip);
2792 }
2793
2794 // Handle Middle Loop
2795 Label label_khp_loop, label_khp_skip;
2796 test(reg_khp, reg_khp);
2797 jz(label_khp_skip, T_NEAR);
2798 mov(reg_cnt_khp, reg_khp);
2799 L(label_khp_loop);
2800 {
2801 Label label_lov, label_lov_skip;
2802 Label label_kwp, label_kwp_skip;
2803 Label label_rov, label_rov_skip;
2804 test(reg_lov, reg_lov);
2805 jnz(label_lov, T_NEAR);
2806 test(reg_kwp, reg_kwp);
2807 jnz(label_kwp, T_NEAR);
2808 test(reg_rov, reg_rov);
2809 jnz(label_rov, T_NEAR);
2810
2811 test(reg_lov, reg_lov);
2812 jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
2813 L(label_lov); // Handle Left Overflow
2814 {
2815 Label label_lov_loop;
2816 mov(reg_cnt_tmp, reg_lov);
2817 L(label_lov_loop);
2818 {
2819 zero_it(reg_ptr_aux_out, 0);
2820 add(reg_ptr_aux_out, out_w_step);
2821 dec(reg_cnt_tmp);
2822 jnz(label_lov_loop, T_NEAR);
2823 }
2824 }
2825 L(label_lov_skip);
2826
2827 test(reg_kwp, reg_kwp);
2828 jz(label_kwp_skip, T_NEAR);
2829 L(label_kwp); // Handle Center Loop
2830 {
2831 Label label_kwp_loop;
2832 mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
2833 mov(reg_cnt_tmp, reg_kwp);
2834 L(label_kwp_loop);
2835 {
2836 copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
2837 add(reg_ptr_aux_out, out_w_step);
2838 add(reg_ptr_aux_inp_w, inp_w_step);
2839 dec(reg_cnt_tmp);
2840
2841 if (jcp.stride_w > 1) {
2842 jz(label_kwp_skip, T_NEAR);
2843 // Handle Dilation-by-Stride
2844 for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
2845 const int offset = sw * out_w_step;
2846 zero_it(reg_ptr_aux_out, offset);
2847 }
2848 add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
2849 if (jcp.stride_w == 2)
2850 dec(reg_cnt_tmp);
2851 else
2852 sub(reg_cnt_tmp, jcp.stride_w - 1);
2853 jmp(label_kwp_loop, T_NEAR);
2854 } else {
2855 jnz(label_kwp_loop, T_NEAR);
2856 }
2857 }
2858 }
2859 L(label_kwp_skip);
2860
2861 test(reg_rov, reg_rov);
2862 jz(label_rov_skip, T_NEAR);
2863 L(label_rov); // Handle Right Overflow
2864 {
2865 Label label_rov_loop;
2866 mov(reg_cnt_tmp, reg_rov);
2867 L(label_rov_loop);
2868 {
2869 zero_it(reg_ptr_aux_out, 0);
2870 add(reg_ptr_aux_out, out_w_step);
2871 dec(reg_cnt_tmp);
2872 jnz(label_rov_loop, T_NEAR);
2873 }
2874 }
2875 L(label_rov_skip);
2876
2877 add(reg_ptr_aux_inp_h, inp_h_step);
2878 dec(reg_cnt_khp);
2879
2880 if (jcp.stride_h > 1) {
2881 jz(label_khp_skip, T_NEAR);
2882 // Handle Dilation-by-Stride
2883 for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
2884 for (int ow = 0; ow < jcp.owp; ow++) {
2885 const int offset = sh * out_h_step + ow * out_w_step;
2886 zero_it(reg_ptr_aux_out, offset);
2887 }
2888 }
2889 add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
2890 if (jcp.stride_h == 2)
2891 dec(reg_cnt_khp);
2892 else
2893 sub(reg_cnt_khp, jcp.stride_h - 1);
2894 jmp(label_khp_loop, T_NEAR);
2895 } else {
2896 jnz(label_khp_loop, T_NEAR);
2897 }
2898 }
2899 L(label_khp_skip);
2900
2901 { // Handle Bottom Overflow
2902 Label label_bov_loop, label_bov_skip;
2903
2904 // number of zero-padded rows below src buffer to copy
2905 mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
2906 test(reg_bov, reg_bov);
2907 jz(label_bov_skip, T_NEAR);
2908 L(label_bov_loop);
2909 {
2910 for (int ow = 0; ow < jcp.owp; ow++) {
2911 const int offset = ow * out_w_step;
2912 zero_it(reg_ptr_aux_out, offset);
2913 }
2914 add(reg_ptr_aux_out, out_h_step);
2915 dec(reg_bov);
2916 jnz(label_bov_loop, T_NEAR);
2917 }
2918 L(label_bov_skip);
2919 }
2920}
2921
2922void jit_avx512_core_amx_bwd_data_copy_kernel_t::kd_loop(bool is_masked) {
2923
2924 Xbyak::Label kd_label, no_kd_label;
2925
2926 const bool is_3d = jcp.ndims == 5;
2927
2928 mov(reg_ptr_aux_out, reg_ptr_out);
2929 mov(reg_ptr_aux_inp_h, reg_ptr_inp);
2930
2931 if (is_3d) {
2932 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
2933 cmp(reg_kd, 0);
2934 jle(no_kd_label, T_NEAR);
2935 L(kd_label);
2936 push(reg_ptr_aux_out);
2937 push(reg_ptr_aux_inp_h);
2938 }
2939
2940 copy_row(is_masked);
2941
2942 if (is_3d) {
2943 const size_t inp_d_offset = static_cast<size_t>(jcp.typesize_in)
2944 * (jcp.dilate_d + 1) * jcp.oh * jcp.ow * jcp.ngroups
2945 * jcp.oc_without_padding;
2946 const size_t out_d_offset = static_cast<size_t>(jcp.typesize_in)
2947 * jcp.ohp * jcp.owp * jcp.oc_block_int;
2948 pop(reg_ptr_aux_inp_h);
2949 pop(reg_ptr_aux_out);
2950 sub(reg_ptr_aux_inp_h, inp_d_offset);
2951 add(reg_ptr_aux_out, out_d_offset);
2952
2953 dec(reg_kd);
2954 jnz(kd_label, T_NEAR);
2955 L(no_kd_label);
2956 }
2957}
2958
2959void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {
2960
2961 const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
2962 const int out_c_step = jcp.kd * jcp.ohp * jcp.owp * inp_c_step;
2963 const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
2964 const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;
2965
2966 preamble();
2967
2968 // pointer to 1st needed element in src buffer
2969 mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
2970 // pointer to 1st needed element in dst buffer
2971 mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);
2972
2973 // number of rows of src buffer to copy
2974 mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
2975
2976 // number of columns of src buffer to copy
2977 mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
2978 // number of zero-padded columns before src buffer to copy
2979 mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
2980    // number of zero-padded columns after src buffer to copy
2981 mov(reg_rov, ptr[param1 + GET_OFF(r_overflow)]);
2982
2983 vpxord(zmm_zero, zmm_zero, zmm_zero);
2984
2985 if (oc_block_int_tail > 0) {
2986 uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
2987 mov(reg_tmp, mask);
2988 kmovq(ktail_mask, reg_tmp);
2989 }
2990
2991 if (nb_oc_int_no_tail == 0) {
2992 kd_loop(true); // masked
2993 } else if (nb_oc_int_no_tail == 1) {
2994 kd_loop(false); // unmasked!
2995 if (oc_block_int_tail > 0) {
2996 add(reg_ptr_inp, inp_c_step);
2997 add(reg_ptr_out, out_c_step);
2998 kd_loop(true); // masked
2999 }
3000 } else if (nb_oc_int_no_tail > 1) {
3001 mov(reg_cnt_ocb, nb_oc_int_no_tail);
3002 Label label_ocb_loop;
3003 L(label_ocb_loop);
3004 {
3005 kd_loop(false); // unmasked!
3006 add(reg_ptr_inp, inp_c_step);
3007 add(reg_ptr_out, out_c_step);
3008 dec(reg_cnt_ocb);
3009 jnz(label_ocb_loop);
3010 }
3011 if (oc_block_int_tail > 0) kd_loop(true); // masked
3012 }
3013
3014 postamble();
3015}
3016
3017// Tile register decomposition
3018// { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
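// Accumulator tiles C occupy indices 0..3, input tiles I indices 4..5, and
// weight tiles W indices 6..7, using all 8 tiles of the palette.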
3019int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const {
3020 const int C_BASE = 0;
3021 const int C_LAST = 4;
3022 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
3023 MAYBE_UNUSED(C_LAST);
3024 const int tile = C_BASE + h * jcp.nb_ih_blocking + i;
3025 assert(C_BASE <= tile && tile < C_LAST);
3026 return tile;
3027}
3028int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const {
3029 const int I_BASE = 4;
3030 const int I_LAST = 6;
3031 assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
3032 MAYBE_UNUSED(I_LAST);
3033 const int tile = I_BASE + h;
3034 assert(I_BASE <= tile && tile < I_LAST);
3035 return tile;
3036}
3037int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const {
3038 const int W_BASE = 6;
3039 const int W_LAST = 8;
3040 assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
3041 MAYBE_UNUSED(W_LAST);
3042 const int tile = W_BASE + i;
3043 assert(W_BASE <= tile && tile < W_LAST);
3044 return tile;
3045}
3046
3047// Strides, shifts and offsets
3048// - inp is a padded buffer ~ [nb_oc_int][kd][ohp][owp]{32c,64c}
3049// - weights is user buffer ~ OIdhw16o16i{2o,4o}
3050// - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c]
3051size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const {
3052 return (size_t)jcp.typesize_in * jcp.kd * jcp.ohp * jcp.owp
3053 * jcp.oc_block_int;
3054}
3055size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const {
3056 return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int;
3057}
3058size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_d_step() const {
3059 return static_cast<size_t>(jcp.typesize_in) * jcp.ohp * jcp.owp
3060 * jcp.oc_block_int;
3061}
3062size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset(
3063 int ihb, int kh, int kw) const {
3064 // calculate offset by src height dimension
3065 size_t sp_offset = (size_t)ihb * jcp.owp;
3066 // add offset by kernel height dimension
3067 sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp;
3068 // add offset by kernel width dimension
3069 sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1);
3070 return jcp.typesize_in * sp_offset * jcp.oc_block_int;
3071}
3072size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const {
3073 return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block;
3074}
3075size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const {
3076 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
3077 return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kd
3078 * jcp.kh * jcp.kw * jcp.oc_block_int * jcp.ic_block;
3079}
3080size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset(
3081 int icb, int kh, int kw) const {
3082 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
3083 const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block;
3084 const size_t wei_kh_stride = jcp.kw * wei_kw_stride;
3085 const size_t wei_kd_stride = jcp.kh * wei_kh_stride;
3086 const size_t wei_icb_stride
3087 = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kd * wei_kd_stride;
3088 return jcp.typesize_in
3089 * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride);
3090}
3091size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_d_step() const {
3092 const size_t wei_kd_stride
3093 = jcp.kh * jcp.kw * jcp.oc_block_int * jcp.ic_block;
3094 // step through 'kd' weight elements by `stride_d`
3095 return static_cast<size_t>(jcp.typesize_in) * jcp.stride_d * wei_kd_stride;
3096}
3097size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset(
3098 int ihb, int icb) const {
3099 size_t el_offset = jcp.is_nspc
3100 ? (size_t)icb * jcp.ic_block
3101 + (size_t)ihb * jcp.iw * jcp.ngroups
3102 * jcp.ic_without_padding
3103 : (size_t)icb * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
3104 + (size_t)ihb * jcp.iw * jcp.ic_block;
3105 return (size_t)jcp.typesize_out * el_offset;
3106}
3107size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset(
3108 int ihb, int icb, int j) const {
3109 size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups
3110 * jcp.ic_without_padding
3111 : (size_t)jcp.typesize_out * j * jcp.ic_block;
3112 return get_out_icb_offset(ihb, icb) + offset_w;
3113}
3114size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const {
3115 return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups
3116 * jcp.ic_without_padding
3117 : (size_t)jcp.typesize_out * width * jcp.ic_block;
3118}
3119size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset(
3120 int ihb, int icb) const {
3121 size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block
3122 + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width
3123 * jcp.ic_block;
3124 return jcp.typesize_acc * el_offset;
3125}
3126size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset(
3127 int ihb, int icb, int j) const {
3128 return get_wsp_icb_offset(ihb, icb)
3129 + (size_t)jcp.typesize_acc * j * jcp.ic_block;
3130}
3131
3132// Code generation
3133void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() {
3134 for (int h = 0; h < jcp.nb_ih_blocking; h++)
3135 for (int i = 0; i < jcp.nb_ic_blocking; i++)
3136 tilezero(Tmm(get_out_tensor(h, i)));
3137}
3138
3139void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters(
3140 bool start_with_last_tile_block) {
3141 prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
3142 ? jcp.tile_tail
3143 : jcp.tile_width;
3144
3145 row_count_ = 0;
3146 is_store_done_ = false;
3147 is_buffer_empty_ = true;
3148}
3149
3150bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) {
3151 using namespace primitive_kind;
3152 const auto &p = attr_.post_ops_;
3153
3154 if (position == 0) {
3155 /* eltwise before sum */
3156 return p.contain(eltwise, 0);
3157 } else if (position == 1) {
3158 /* eltwise after sum */
3159 return p.contain(sum, 0) && p.contain(eltwise, 1);
3160 }
3161
3162 return false;
3163}
3164
3165Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask(
3166 const Ymm &ymm_in, bool mask_flag, bool store) {
3167 return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
3168 : ymm_in;
3169}
3170
3171Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask(
3172 const Zmm &zmm_in, bool mask_flag, bool store) {
3173 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
3174 : zmm_in;
3175}
3176
3177void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
3178 const Zmm &zmm_in, const Operand &op, bool mask_flag) {
3179 const Zmm zmm = zmm_mask(zmm_in, mask_flag);
3180 switch (type_in) {
3181 case data_type::f32:
3182 case data_type::s32: vmovups(zmm, op); break;
3183 case data_type::s8: vpmovsxbd(zmm, op); break;
3184 case data_type::u8: vpmovzxbd(zmm, op); break;
3185 default: assert(!"unsupported data type");
3186 }
3187 if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
3188}
3189
3190void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
3191 const Zmm &zmm_out, int icb, int h, int w) {
3192 const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
3193 && icb == (jcp.nb_ic_blocking - 1);
3194
3195 auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3196
3197 const auto &p = attr_.post_ops_;
3198
3199 const int sum_idx = p.find(primitive_kind::sum);
3200 if (sum_idx != -1) {
3201 if (jcp.dsrc_dt == data_type::bf16) {
3202 vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
3203 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
3204 vaddps(zmm_out, zmm_prev_dst);
3205 } else {
3206 vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
3207 vaddps(zmm_out, zmm_prev_dst);
3208 }
3209 }
3210 if (jcp.with_bias) {
3211 int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
3212 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3213 if (jcp.bia_dt == data_type::bf16) {
3214 vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
3215 vpslld(zmm_bias, zmm_bias, 16);
3216 vaddps(zmm_out, zmm_bias);
3217 } else
3218 vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
3219 }
3220
3221 const int eltwise_ind = p.find(primitive_kind::eltwise);
3222 if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());
3223
3224 if (jcp.dsrc_dt == data_type::bf16) {
3225 Ymm ymm_out = Ymm(zmm_out.getIdx());
3226 vcvtneps2bf16(ymm_out, zmm_out);
3227 vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
3228 } else {
3229 vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
3230 }
3231}
3232
3233void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
3234 const Zmm &zmm_out, int icb, int h, int w) {
3235 const int nb_ic_block = jcp.nb_ic_blocking;
3236 const int ic_block = jcp.ic_block;
3237 const bool mask_flag = true && jcp.ic_without_padding != jcp.ic
3238 && icb == (nb_ic_block - 1);
3239
3240 auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3241
3242 const auto &p = attr_.post_ops_;
3243 const int sum_idx = p.find(primitive_kind::sum);
3244 const float *p_sum_scale = nullptr;
3245 const int32_t *p_sum_zp = nullptr;
3246 if (sum_idx != -1) {
3247 const auto &p_entry = p.entry_[sum_idx];
3248 p_sum_scale = &p_entry.sum.scale;
3249 p_sum_zp = &p_entry.sum.zero_point;
3250 }
3251
3252 if (p_sum_scale) {
3253 if (*p_sum_scale != 1.f)
3254 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
3255 if (*p_sum_zp != 0)
3256 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
3257 }
3258
3259 int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
3260 if (jcp.with_bias) {
3261 int bias_offset = jcp.typesize_bia * icb * ic_block;
3262 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3263 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
3264 }
3265
    /* convert the accumulator to f32, apply scales, then add bias */
3267 vcvtdq2ps(zmm_out, zmm_out);
3268 const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
3269 vmulps(zmm_out_msk, zmm_out,
3270 EVEX_compress_addr(reg_ptr_scales, scale_offset));
3271 if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
3272
3273 /* Do post-ops */
3274 if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3275 if (p_sum_scale) { // post_op: sum
3276 cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
3277 if (*p_sum_zp != 0) {
3278 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
3279 vsubps(zmm_prev_dst, zmm_sum_zp);
3280 }
3281 if (*p_sum_scale == 1.f)
3282 vaddps(zmm_out, zmm_prev_dst);
3283 else
3284 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
3285 }
3286 if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3287
3288 if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }
3289
    // Properly saturate the accumulators for integer data types
3291 if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
3292 init_saturate_f32(
3293 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
3294 saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
3295 vcvtps2dq(zmm_out, zmm_out);
3296 }
3297
3298 const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
3299
3300 switch (jcp.dsrc_dt) {
3301 case data_type::f32:
3302 case data_type::s32: vmovups(addr, zmm_out_store); break;
3303 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
3304 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
3305 default: assert(!"unknown dst_dt");
3306 }
3307}
3308
3309void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector(
3310 const Zmm &zmm_out, int icb, int h, int w) {
    /*
    Output layouts:
                jcp.is_nspc                 !jcp.is_nspc
                --------------------------  --------------------------
        INT8:   [N][D][H][W][NBIC][16IC]    (nspc only)
        BF16:   [N][D][H][W][NBIC][16IC]    [N][NBIC][D][H][W][16IC]
    */
3318 if (jcp.ddst_dt == data_type::bf16) {
3319 store_output_vector_bf16(zmm_out, icb, h, w);
3320 } else {
3321 store_output_vector_int8(zmm_out, icb, h, w);
3322 }
3323}
3324
3325void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
3326 int width, bool do_store) {
3327 auto store_output_block = [=](int width, bool do_store,
3328 bool is_last_ih_blks) {
3329 // Calculate the number of ih blocks; it may differ on last call
3330 const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
3331 : jcp.nb_ih_blocking;
3332 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3333 for (int ihb = 0; ihb < n_ih_blks; ihb++) {
3334 /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
3335 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
3336 + get_wsp_icb_offset(ihb, icb)],
3337 Tmm(get_out_tensor(ihb, icb)));
3338 is_buffer_empty_ = false;
3339 is_store_done_ = false;
3340 for (int tw = 0; tw < width && do_store; tw++) {
3341 Zmm zmm_out = Zmm(tw);
3342 vmovups(zmm_out,
3343 ptr[reg_wsp_ptr
3344 + get_wsp_row_offset(ihb, icb, tw)]);
3345 store_output_vector(zmm_out, icb, ihb, tw);
3346 }
3347 }
3348 }
3349 };
3350
3351 // adjustment in case interleave store is turned off
3352 do_store = do_store || jcp.per_one_pstore == 0;
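    // When ih is not a multiple of nb_ih_blocking, dispatch at runtime on the
    // 'last_h' kernel argument between a full ih block and the ih tail.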
3353 if (jcp.ih % jcp.nb_ih_blocking == 0) {
3354 store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3355 } else {
3356 Label label_full_store, label_done;
3357 cmp(reg_last_h, 0);
3358 jne(label_full_store, T_NEAR);
3359 store_output_block(width, do_store, /* is_last_ih_blks = */ true);
3360 jmp(label_done, T_NEAR);
3361 L(label_full_store);
3362 store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3363 L(label_done);
3364 }
3365 if (do_store) add(reg_out_ptr, get_out_shift(width));
3366}
3367
3368void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
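    // Interleave stores of already accumulated rows (kept in the s32
    // workspace) with the AMX compute: at most 'per_one_pstore' rows are
    // converted and written out per call, and compute_ocb_loop() issues one
    // call after every tdp* instruction.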
3369 for (int c = 0;
3370 c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
3371 c++) {
3372 // row_count = ihb * ICB * TW + icb * TW + tw
3373 int tw = row_count_ % prv_width_;
3374 int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
3375 int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;
3376
3377 Zmm zmm_out = Zmm(tw);
3378 vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
3379 store_output_vector(zmm_out, icb, ihb, tw);
3380 row_count_++;
3381
3382 if (row_count_
3383 == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
3384 add(reg_out_ptr, get_out_shift(prv_width_));
3385
3386 is_store_done_save_ = is_store_done_;
3387 prv_width_save_ = prv_width_;
3388
3389 row_count_ = 0;
3390 is_store_done_ = true;
3391 prv_width_ = width;
3392 }
3393 }
3394}
3395
3396void jit_avx512_core_amx_bwd_data_kernel_t::skipped_interleave_store() {
3397
3398 if (is_store_done_save_ || is_buffer_empty_) return;
3399
3400 const int store_count
3401 = prv_width_save_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking;
3402 for (int row_count = 0; row_count < store_count; row_count++) {
3403 // row_count = ihb * ICB * TW + icb * TW + tw
3404 int tw = row_count % prv_width_save_;
3405 int icb = (row_count / prv_width_save_) % jcp.nb_ic_blocking;
3406 int ihb = (row_count / prv_width_save_) / jcp.nb_ic_blocking;
3407
3408 Zmm zmm_out = Zmm(tw);
3409 vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
3410 store_output_vector(zmm_out, icb, ihb, tw);
3411 }
3412 is_store_done_save_ = true;
3413 add(reg_out_ptr, get_out_shift(prv_width_save_));
3414}
3415
3416void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
3417 int width, bool do_interleave_store) {
3418
3419 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
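        // Dispatch the AMX dot-product instruction on the diff_dst data type
        // (bf16 or signed / unsigned int8 activations).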
3420 switch (jcp.ddst_dt) {
3421 using namespace data_type;
3422 case bf16: tdpbf16ps(x1, x2, x3); break;
3423 case s8: tdpbssd(x1, x2, x3); break;
3424 case u8: tdpbusd(x1, x2, x3); break;
3425 default: assert(!"unsupported data type");
3426 }
3427 };
3428
3429 for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
3430 // reverse order through spatial components of weights so that
3431 // input buffer is accessed in a monotonically increasing fashion
3432 for (int kh = jcp.kh - 1; kh >= 0; kh--) {
3433 for (int kw = jcp.kw - 1; kw >= 0; kw--) {
3434 for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3435 tileloadd(Tmm(get_inp_tensor(ihb)),
3436 ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
3437 + reg_inp_stride]);
3438 }
3439 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3440 tileloadd(Tmm(get_wei_tensor(icb)),
3441 ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
3442 + reg_wei_stride]);
3443 for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3444 tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
3445 Tmm(get_inp_tensor(ihb)),
3446 Tmm(get_wei_tensor(icb)));
3447 if (do_interleave_store) interleave_store(width);
3448 }
3449 }
3450 }
3451 }
3452 add(reg_inp_ptr, get_inp_ocb_step());
3453 add(reg_wei_ptr, get_wei_ocb_step());
3454 }
3455 sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
3456 sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);
3457}
3458
3459void jit_avx512_core_amx_bwd_data_kernel_t::compute_kd_loop(
3460 int width, bool do_store, bool handle_skipped_stores) {
3461
3462 Label skip_compute_kd_label, kd_loop_label, end_kd_compute_label;
3463
3464 prepare_output();
3465
3466 if (jcp.ndims == 5) {
3467 push(reg_inp_ptr);
3468 push(reg_wei_ptr);
3469
3470 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
3471 cmp(reg_kd, 0);
3472 jle(skip_compute_kd_label, T_NEAR);
3473 }
3474
3475 compute_ocb_loop(width, true);
3476
3477 if (jcp.ndims == 5) {
3478 L(kd_loop_label);
3479
3480 // for bwd_d, filter elements are stepped through in reverse, e.g.:
3481 // diff_dst: [0, 1, 2, 3, 4] diff_src:
3482 // wei: [2, 1, 0] [0]
3483 // [2, 1, 0] [1]
3484 // [2, 1, 0] [2]
3485 // which results in 'data_copy_kernel_t' copying the 'kd' dimension
3486 // corresponding to 'diff_dst' elements in reverse. The layout for the
3487 // 'kd' elements of 'inp_buff' to compute 'diff_src[0]' are:
3488 //
3489 // [kd]:
3490 // inp_buff: [nb_oc_int]{2,1,0}[ohp][owp]{32c,64c}
3491 //
3492 // then,
3493 // inp_buff: [nb_oc_int]{3,2,1}[ohp][owp]{32c,64c}
3494 //
3495 // and so on...
3496 //
3497 // hence, step through 'reg_inp_ptr' in ascending order:
3498 add(reg_inp_ptr, get_inp_d_step());
3499 add(reg_wei_ptr, get_wei_d_step());
3500
3501 dec(reg_kd);
3502 jz(end_kd_compute_label, T_NEAR);
3503
3504 // because 'kd' is dynamic, it may skip interleaved stores, so do not
3505 // unroll for more than one kd iteration.
3506 compute_ocb_loop(width, false);
3507 jmp(kd_loop_label, T_NEAR);
3508
3509 L(skip_compute_kd_label);
3510
        // 'kd_padding' may be '0' because of gaps in the filter / diff_dst
        // computation caused by stride or dilation, in which case
        // 'compute_ocb_loop()' is skipped along with its interleaved stores.
        // Handle that case with the special 'skipped_interleave_store()'.
3515 if (handle_skipped_stores) skipped_interleave_store();
3516
3517 L(end_kd_compute_label);
3518
3519 pop(reg_wei_ptr);
3520 pop(reg_inp_ptr);
3521 }
3522
3523 store_output(width, do_store);
3524
3525 add(reg_inp_ptr, get_inp_shift());
3526}
3527
3528void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
3529 auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
3530 // check if there are '0' gaps in input stores due to dilation or stride
3531 bool handle_skipped_stores = gaps_in_store() && num_tile_blocks > 1;
3532
3533 int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
3534 : jcp.tile_width;
3535 init_runtime_counters(last_iwb && num_tile_blocks == 1);
3536 for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
3537 compute_kd_loop(jcp.tile_width, false, handle_skipped_stores);
3538 compute_kd_loop(gen_tile_tail, true, handle_skipped_stores);
3539 };
3540
3541 if (jcp.nb_iw == 1) {
3542 compute_iw_loop_body(true, jcp.iw_blocks);
3543 } else {
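        // iw is split into several blocks (nb_iw > 1): the last block may
        // carry fewer tile blocks, so detect it at runtime via the 'iwb'
        // kernel argument.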
3544 Label label_done;
3545 int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
3546 int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
3547 if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
3548 last_iwb_tile_blocks = iw_blocks_per_call;
3549 if (last_iwb_tile_blocks > 0) {
3550 Label label_not_last_iwb;
3551 mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
3552 cmp(reg_tmp, jcp.nb_iw - 1);
3553 jne(label_not_last_iwb, T_NEAR);
3554
3555 compute_iw_loop_body(true, last_iwb_tile_blocks);
3556
3557 jmp(label_done, T_NEAR);
3558
3559 L(label_not_last_iwb);
3560 }
3561 compute_iw_loop_body(false, iw_blocks_per_call);
3562
3563 L(label_done);
3564 }
3565}
3566
3567void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
3568 preamble();
3569
3570 mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
3571 mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
3572 mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
3573 mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
3574
3575 if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
3576
3577 if (jcp.dst_scale) {
3578 mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]);
3579 vmovups(zmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scales, 0));
3580 }
3581 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
3582
3583 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
3584
3585 const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
3586 const int wei_stride = jcp.ic_block * jcp.typesize_acc;
3587 mov(reg_inp_stride, inp_stride);
3588 mov(reg_wei_stride, wei_stride);
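    // tileloadd / tilestored take the distance (in bytes) between consecutive
    // tile rows from the general-purpose register used in the address.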
3589
3590 if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
        // Use the full mask (all jcp.ic_block bits set) by default for all
        // output data and post-ops loads / stores with block index
        // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
        // TODO: use masked loads / stores for the last icc only
3595 int current_block_size = jcp.ic_block;
3596 int mask = (1 << current_block_size) - 1;
3597 Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
3598 mov(regw_tmp, mask);
3599 kmovw(ktail_mask, regw_tmp);
3600 Xbyak::Label mask_is_set;
3601 mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
3602 cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
3603 jne(mask_is_set, T_NEAR);
3604 // Reset the mask
3605 current_block_size = jcp.ic_without_padding % jcp.ic_block;
3606 mask = (1 << current_block_size) - 1;
3607 mov(regw_tmp, mask);
3608 kmovw(ktail_mask, regw_tmp);
3609
3610 L(mask_is_set);
3611 }
3612 compute_iw_loop();
3613
3614 postamble();
3615
3616 if (jcp.with_eltwise) eltwise_injector_->prepare_table();
3617}
3618
3619bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
3620 const jit_conv_conf_t &jcp, primitive_attr_t &attr) {
3621 using namespace primitive_kind;
3622 const auto &p = attr.post_ops_;
3623 const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
3624
3625 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
3626
3627 auto is_sum = [&](int idx) {
3628 if (is_bf16)
3629 return p.entry_[idx].is_sum();
3630 else
3631 return p.contain(sum, idx);
3632 };
3633
3634 switch (p.len()) {
3635 case 0: return true;
3636 case 1: return is_eltwise(0) || is_sum(0);
3637 case 2:
3638 return (is_sum(0) && is_eltwise(1))
3639 || (!is_bf16 && is_sum(1) && is_eltwise(0));
3640 default: return false;
3641 }
3642
3643 return false;
3644}
3645
3646void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
3647 const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
3648 // Input tile dimensions
3649 const int a_col = jcp.oc_block_int;
3650 const int a_row = jcp.tile_width;
3651 // Weights tile dimensions
3652 const int b_col = jcp.ic_block * vnni_width;
3653 const int b_row = a_col / vnni_width;
3654 // Accumulator tile dimensions
3655 const int c_col = jcp.ic_block;
3656 const int c_row = a_row;
3657
3658 for (size_t i = 0; i < 64; i++)
3659 tcfg_buff[i] = 0;
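    // The palette config is a 64-byte structure: zeroed above, then the
    // per-tile rows / bytes-per-row entries are filled in below.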
3660
3661 // Weights (W_BASE) Tensor Tiles
3662 for (int i = 0; i < jcp.nb_ic_blocking; i++)
3663 tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
3664 b_row, b_col * jcp.typesize_in);
3665
3666 // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
3667 for (int h = 0; h < jcp.nb_ih_blocking; h++) {
3668 tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
3669 a_row, a_col * jcp.typesize_in);
3670 for (int i = 0; i < jcp.nb_ic_blocking; i++)
3671 tc_configure_tile((palette_config_t *)tcfg_buff,
3672 get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
3673 }
3674
3675 ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette();
3676}
3677
3678status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3679 const convolution_desc_t &cd, memory_desc_t &diff_src_md,
3680 memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
3681 memory_desc_t *bias_md, primitive_attr_t &attr, int nthreads) {
3682 using namespace prop_kind;
3683
3684 const memory_desc_wrapper diff_src_d(&diff_src_md);
3685 const memory_desc_wrapper weights_d(&weights_md);
3686 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
3687 const memory_desc_wrapper bias_d(bias_md);
3688
3689 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
3690 int ndims = diff_src_d.ndims();
3691 bool is_1d = ndims == 3;
3692 bool is_3d = ndims == 5;
3693
3694 using namespace data_type;
3695 const bool is_deconv = cd.prop_kind != prop_kind::backward_data;
3696
3697 const bool is_bf16 = everyone_is(true, diff_dst_d.data_type() == bf16,
3698 weights_d.data_type() == bf16,
3699 one_of(diff_src_d.data_type(), bf16, f32));
3700 const bool is_bf16_convolution = is_bf16 && !is_deconv;
3701 const bool is_bf16_deconvolution = is_bf16 && is_deconv;
3702 const bool is_int8_deconvolution = is_deconv
3703 && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
3704 weights_d.data_type() == s8,
3705 one_of(diff_src_d.data_type(), f32, s32, s8, u8));
3706
3707 bool supported
3708 = mayiuse(avx512_core_amx) && (is_bf16 || is_int8_deconvolution);
3709 if (!supported) return status::unimplemented;
3710
3711 jcp = zero<decltype(jcp)>();
3712 jcp.isa = avx512_core_amx;
3713 jcp.ndims = ndims;
3714 jcp.prop_kind = cd.prop_kind;
3715 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
3716
3717 jcp.mb = diff_src_d.dims()[0];
3718 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
3719 jcp.oc_without_padding = jcp.oc;
3720 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
3721 jcp.ic_without_padding = jcp.ic;
3722 jcp.id = is_3d ? diff_src_d.dims()[2] : 1;
3723 jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
3724 jcp.iw = diff_src_d.dims()[ndims - 1];
3725 jcp.od = is_3d ? diff_dst_d.dims()[2] : 1;
3726 jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
3727 jcp.ow = diff_dst_d.dims()[ndims - 1];
3728 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
3729 jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
3730 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
3731 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
3732 jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
3733 jcp.l_pad = cd.padding[0][ndims - 3];
3734 jcp.stride_d = is_3d ? cd.strides[0] : 1;
3735 jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
3736 jcp.stride_w = cd.strides[ndims - 3];
3737
3738 // No bias for bf16 case to simplify integration with ref_deconvolution
3739 jcp.with_bias = bias_md && !is_bf16_convolution
3740 && cd.bias_desc.format_kind != format_kind::undef;
3741
3742 jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
3743 jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
3744 jcp.dilate_w = cd.dilates[ndims - 3];
3745
3746 if (jcp.dilate_d != 0 && jcp.stride_d != 1) return status::unimplemented;
3747
3748 const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
3749 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
3750 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
3751 jcp.back_pad = calculate_end_padding(
3752 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
3753 jcp.b_pad = calculate_end_padding(
3754 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
3755 jcp.r_pad = calculate_end_padding(
3756 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
3757 if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
3758 || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
3759 || jcp.back_pad >= gen_kd)
3760 return status::unimplemented;
3761
3762 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
3763 if (is_deconv) {
3764 jcp.ddst_dt = cd.src_desc.data_type;
3765 jcp.dsrc_dt = cd.dst_desc.data_type;
3766 } else {
3767 jcp.ddst_dt = cd.diff_dst_desc.data_type;
3768 jcp.dsrc_dt = cd.diff_src_desc.data_type;
3769 }
3770 jcp.wei_dt = cd.weights_desc.data_type;
3771
3772 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
3773
3774 if (jcp.is_depthwise)
3775 return status::unimplemented; // TODO: add support of DW convolution
3776
3777 format_tag_t dat_tag_ncsp = pick(ndims - 3, format_tag::nCw16c,
3778 format_tag::nChw16c, format_tag::nCdhw16c);
3779 format_tag_t dat_tag_nspc = pick(
3780 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
3781 // To toggle the default data layout for BF16 between nChw16c and nhwc,
3782 // swap the following two variable definitions. Current choice: nhwc.
3783 format_tag_t dat_tag_opt = dat_tag_nspc;
3784 format_tag_t dat_tag_alt = is_bf16 ? dat_tag_ncsp : dat_tag_nspc;
3785
3786 if (diff_src_d.format_kind() == format_kind::any) {
3787 CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
3788 jcp.src_tag = dat_tag_opt;
3789 } else
3790 jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
3791
3792 if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
3793 return status::unimplemented;
3794
3795 jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
3796 assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));
3797
3798 // TODO: remove all support for nChw16c from this implementation
3799 if (!jcp.is_nspc) return status::unimplemented;
3800
3801 if (diff_dst_d.format_kind() == format_kind::any) {
3802 CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
3803 jcp.dst_tag = jcp.src_tag;
3804 } else
3805 jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
3806
3807 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
3808
3809 if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
3810 CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));
3811
3812 jcp.nthr = nthreads;
3813
3814 jcp.ic_block = 16;
3815 jcp.oc_block = 16;
3816
3817 if (jcp.ngroups == 1) {
3818 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
3819 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
3820 }
3821 bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
3822 if (!args_ok) return status::unimplemented;
3823
3824 const int vnni_width = is_bf16 ? 2 : 4;
3825 jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8
3826
3827 if (attr.set_default_formats(&diff_src_md) != status::success)
3828 return status::unimplemented;
3829 if (!post_ops_ok(jcp, attr)) return status::unimplemented;
3830
3831 const auto &p = attr.post_ops_;
3832 const int eltwise_ind = p.find(primitive_kind::eltwise);
3833 jcp.with_eltwise = eltwise_ind != -1;
3834 if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
3835
3836 auto set_or_check_wei_format = [&]() {
3837 using namespace format_tag;
3838 format_tag_t wei_tag;
3839 if (is_bf16_convolution)
3840 wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
3841 gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o, OIdhw16o16i2o,
3842 gOIdhw16o16i2o);
3843 else if (is_bf16_deconvolution)
3844 wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
3845 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i, OIdhw16i16o2i,
3846 gOIdhw16i16o2i);
3847 else if (is_int8_deconvolution)
3848 wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
3849 gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i, OIdhw16i16o4i,
3850 gOIdhw16i16o4i);
3851 else {
3852 assert(!"unsupported combination");
3853 return false;
3854 }
3855
3856 memory_desc_t want_wei_md = weights_md;
3857 memory_desc_init_by_tag(want_wei_md, wei_tag);
3858
3859 if (weights_md.format_kind == format_kind::any) {
3860 weights_md = want_wei_md;
3861 return true;
3862 }
3863 return weights_md == want_wei_md;
3864 };
3865
3866 if (!set_or_check_wei_format()) return status::unimplemented;
3867
3868 jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
3869 jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
3870 jcp.typesize_bia
3871 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
3872 jcp.typesize_acc = sizeof(int32_t);
3873
3874 jcp.nb_ic = jcp.ic / jcp.ic_block;
3875 jcp.nb_oc = jcp.oc / jcp.oc_block;
3876 jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);
3877
3878 const int target_palette = amx::get_target_palette();
3879 jcp.max_tiles = amx::get_max_tiles(target_palette);
3880 jcp.full_tile_width = amx::get_max_rows(target_palette);
3881 if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
3882 return status::unimplemented;
3883
3884 jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
3885 jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);
3886
    // Prefer to use a single tile width when possible
    // (e.g., iw = 28 => 2 tiles of 14 vs. 1 of 16 and 1 of 12)
3889 if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
3890 jcp.tile_tail = jcp.iw % jcp.tile_width;
3891
3892 jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
3893 jcp.nb_ih_blocking
3894 = everyone_is(true, jcp.ih > 1,
3895 // requirement for interleave stores
3896 IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
3897 ? 2
3898 : 1;
3899
3900 // TODO: tune ih blocking
3901 const int ih_blk_size_tmp = 10;
3902 const int ih_step = jcp.nb_ih_blocking;
3903 jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
3904 // ohp includes all elements that are really used in calculation,
3905 // including zero-padded "dilate-by-strides" and top and bottom overflow
3906 jcp.ohp = jcp.ih_blk_size + gen_kh - 1;
3907
3908 // TODO: tune iw blocking
3909 const int iw_blocks_per_call = 2;
3910 jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
3911 jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
3912 // owp includes all elements that are really used in calculation,
3913 // including zero-padded "dilate-by-strides" and left and right overflow
3914 jcp.owp = jcp.iw_block + gen_kw - 1;
3915
3916 // Number of ops per tile store
3917 int ops_tile_store = jcp.tile_width;
3918 // Number of ops per accumulation tile
    int available_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
3920 // Number of vectors to store per tile operation
3921 // NOTE: set to zero to turn off interleave store (mostly for debugging)
    jcp.per_one_pstore = div_up(ops_tile_store, available_ops);
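    // For illustration: tile_width = 16 with nb_oc_int * kh * kw = 9 gives
    // per_one_pstore = div_up(16, 9) = 2, i.e. two accumulator rows are
    // written out by each interleave_store() call.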
3923
3924 jcp.inp_buffer_size = static_cast<size_t>(jcp.nb_oc_int) * jcp.kd * jcp.ohp
3925 * jcp.owp * jcp.oc_block_int;
3926 jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
3927 * jcp.full_tile_width * jcp.ic_block;
3928
3929 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
3930 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
3931 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
3932 const int wei_mask_per_ic = 1 << (int)with_groups;
3933 jcp.is_ic_scale = wei_scales.mask_ == wei_mask_per_ic;
3934 jcp.dst_scale = !dst_scales.has_default_values();
3935
3936 // only common src & dst scales are supported
3937 // only common and per-oc-channel weight scales are supported
3938 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_ic)
3939 && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
3940 if (!scales_ok) return status::unimplemented;
3941
3942 return status::success;
3943}
3944
3945void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
3946 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
3947 const primitive_attr_t &attr) {
3948
3949 size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
3950 scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
3951 size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
3952 scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
3953 if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
3954 assert(jcp.ngroups == 1);
3955 scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
3956 }
3957 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
3958
3959 book_precomputed_scales(
3960 scratchpad, attr.scales_, jcp.ngroups * jcp.ic_without_padding);
3961}
3962
3963const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;
3964
3965// Tile register decomposition
3966// { C_BASE = 0, A_BASE = 4, B_BASE = 6, }
3967int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor(
3968 int ocb, int icb) const {
3969 const int C_BASE = 0;
3970 const int C_LAST = 4;
3971 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
3972 MAYBE_UNUSED(C_LAST);
3973 const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb;
3974 assert(C_BASE <= tile && tile < C_LAST);
3975 return tile;
3976}
3977int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const {
3978 const int A_BASE = 4;
3979 const int A_LAST = 6;
3980 assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles);
3981 MAYBE_UNUSED(A_LAST);
3982 const int tile = A_BASE + icb;
3983 assert(A_BASE <= tile && tile < A_LAST);
3984 return tile;
3985}
3986int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const {
3987 const int B_BASE = 6;
3988 const int B_LAST = 8;
3989 assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles);
3990 MAYBE_UNUSED(B_LAST);
3991 const int tile = B_BASE + ocb;
3992 assert(B_BASE <= tile && tile < B_LAST);
3993 return tile;
3994}
3995
3996void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(char *tcfg_buff) {
3997 // Input tile dimensions
3998 const int a_col = jcp.ur_w;
3999 const int a_row = jcp.ic_block;
4000 // Weights tile dimensions
4001 const int b_col = jcp.oc_block * 2;
4002 const int b_row = a_col / 2;
4003 // Accumulator tile dimensions
4004 const int c_col = jcp.oc_block;
4005 const int c_row = a_row;
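    // bf16 VNNI layout packs two elements of the reduction dimension per
    // tile row, hence b_row = a_col / 2 and b_col = oc_block * 2.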
4006
4007 for (size_t i = 0; i < 64; i++)
4008 tcfg_buff[i] = 0;
4009
4010 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
4011 tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb),
4012 a_row, a_col * jcp.typesize_in);
4013
4014 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
4015 tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb),
4016 b_row, b_col * jcp.typesize_in);
4017
4018 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
4019 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
4020 tc_configure_tile((palette_config_t *)tcfg_buff,
4021 get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out);
4022
4023 ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette();
4024}
4025
4026void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
4027 Label kd_comeback_label;
4028 mov(kj, reg_kd_count);
4029 L(kd_comeback_label);
4030 {
4031 sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
4032 sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
4033 dec(kj);
4034 jnz(kd_comeback_label, T_NEAR);
4035 }
4036}
4037
4038void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
4039 Label kh_comeback_label;
4040 mov(kj, reg_kh);
4041 L(kh_comeback_label);
4042 {
4043 sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
4044 sub(reg_kernel, get_kernel_offset(0, jcp.kw));
4045 dec(kj);
4046 jnz(kh_comeback_label, T_NEAR);
4047 }
4048}
4049
4050void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
4051 int nb_ic_blocking, int nb_oc_blocking) {
4052 // General code layout:
4053 //
4054 // Blocking over OH -- top level
4055 // (Reduces L2 pressure; not very useful right now)
4056 // Loop over all KHxKW kernel -- emit_kh_kw_loop()
4057 // Loop over OH block -- emit_h_loop()
4058 // Loop over OW blocks -- emit_fma_block()
4059 // (Supports both fully unrolled and partially unrolled
4060 // versions to reduce code size)
4061 // Loop over OW block -- emit_fma_step()
4062
4063 auto src_row_size = get_src_offset(0, 0, 1);
4064 auto ddst_row_size = get_ddst_offset(0, 1);
4065 auto row_size = src_row_size + ddst_row_size;
4066
4067 int h_block_size = jcp.oh;
4068 int h_last_block_size = h_block_size;
4069 int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
4070 auto working_set_size = row_size * h_block_size;
4071
4072 if (working_set_size > full_spat_max_working_set_size) {
4073 assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
4074
4075 while (working_set_size > full_spat_opt_working_set_size
4076 && h_block_size >= min_h_block_size) {
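            // shrink h_block_size by its smallest divisor greater than 1;
            // if it is prime, simply halve it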
4077 for (int i = 2; i <= h_block_size; i++)
4078 if (i == h_block_size)
4079 h_block_size = h_block_size / 2;
4080 else if (h_block_size % i == 0) {
4081 h_block_size = h_block_size / i;
4082 break;
4083 }
4084 working_set_size = row_size * h_block_size;
4085 }
4086 h_block_size = nstl::max(min_h_block_size, h_block_size);
4087 h_last_block_size = jcp.oh % h_block_size;
4088 if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
4089 }
4090
4091 Opmask reg_h_block = k1;
4092 Reg64 reg_kh = rax;
4093 Reg64 reg_kw = rbx;
4094 Reg64 reg_tmp = abi_not_param1;
4095 Reg32 reg_tmp_w = reg_tmp.cvt32();
4096 Reg64 reg_ohs = rdx;
4097 Reg64 reg_ihs = rsi;
4098 Reg64 reg_h = r8;
4099 Reg64 reg_j = r10;
4100
4101 Reg64 reg_src = r13;
4102 Reg64 reg_ddst = r14;
4103 Reg64 reg_ker = r15;
4104
4105 Reg64 reg_dense_stride = abi_param1;
4106 Reg64 reg_a_stride = reg_tmp;
4107
4108 auto emit_block = [&]() {
4109 mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
4110 for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
4111 dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
4112 dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
4113
4114 for (int icb = 0; icb < nb_ic_blocking; icb++) {
4115 dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
4116 tileloadd(Tmm(get_src_tensor(icb)),
4117 ptr[reg_src + reg_a_stride + icb_offset
4118 + ur_w_src_offset]);
4119 }
4120 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4121 tileloadd(Tmm(get_ddst_tensor(ocb)),
4122 ptr[reg_ddst + reg_dense_stride
4123 + jcp.typesize_in * ocb
4124 * jcp.tr_diff_dst_buf_size
4125 + ur_w_ddst_offset]);
4126 for (int icb = 0; icb < nb_ic_blocking; icb++)
4127 tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
4128 Tmm(get_src_tensor(icb)),
4129 Tmm(get_ddst_tensor(ocb)));
4130 }
4131 }
4132 };
4133
4134 auto emit_h_loop = [&]() {
4135 Label h_loop, skip_h_loop;
4136 mov(reg_j, 1);
4137 cmp(reg_j, reg_h);
4138 je(skip_h_loop, T_NEAR);
4139 L(h_loop);
4140 {
4141 emit_block();
4142
4143 add(reg_src, get_src_offset(0, 0, 1));
4144 add(reg_ddst, get_ddst_offset(0, 1));
4145 add(reg_j, 1);
4146 cmp(reg_j, reg_h);
4147 jb(h_loop);
4148 }
4149 L(skip_h_loop);
4150
4151 emit_block();
4152 };
4153
4154 auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
4155 xor_(reg_kh, reg_kh);
4156 Label kh_loop, kh_loop_end;
4157
4158 int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
4159 // NB: this is correct because we only support t_pad = kh / 2 and thus
4160 // ih == oh
4161 int ih_block_size = oh_block_size
4162 + (!is_first_block + !is_last_block) * jcp.t_pad;
4163
4164 L(kh_loop);
4165 {
4166 if (is_first_block) {
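                // first block: reg_ohs = max(t_pad - kh, 0) is the first
                // output row overlapping this kernel row; reg_ihs is derived
                // from it below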
4167 xor_(reg_tmp, reg_tmp);
4168 mov(reg_ohs, jcp.t_pad);
4169 sub(reg_ohs, reg_kh);
4170 cmovb(reg_ohs, reg_tmp);
4171
4172 mov(reg_ihs, reg_ohs);
4173 sub(reg_ihs, jcp.t_pad);
4174 add(reg_ihs, reg_kh);
4175 } else {
4176 xor_(reg_ohs, reg_ohs);
4177 mov(reg_ihs, reg_kh);
4178 }
4179
4180 mov(reg_tmp, oh_block_size);
4181 sub(reg_tmp, reg_ohs);
4182 mov(reg_h, ih_block_size);
4183 sub(reg_h, reg_ihs);
4184 cmp(reg_tmp, reg_h);
4185 cmovb(reg_h, reg_tmp);
4186
4187 Label kh_loop_work;
4188 cmp(reg_h, 0);
4189 jg(kh_loop_work, T_NEAR);
4190
            // empty h loop for this jcp.kh:
            // - zero the kernel (diff_weights) tiles if necessary
            // - move the ker ptr
            // - jump to the end
4195 sub(reg_h, 1);
4196 Label skip_ker_zeroing;
4197
            // The low bit of the reg_ker ptr is set if the kernel
            // (diff_weights) needs to be zeroed. Those who have byte-aligned
            // their data will suffer the consequences :(
4201 // TODO: move the flag to a mask register? (Roma)
4202 test(reg_ker, 1);
4203 jz(skip_ker_zeroing, T_NEAR);
4204
4205 Label zeroing_loop;
4206 vpxord(zmm0, zmm0, zmm0);
4207 and_(reg_ker, ~1); // temporarily clear the zeroing flag
4208
4209 mov(reg_dense_stride, 64);
4210 tilezero(Tmm(get_wei_tensor(0, 0)));
4211 for (int kw = 0; kw < jcp.kw; kw++) {
4212 // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
4213 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4214 for (int icb = 0; icb < nb_ic_blocking; icb++)
4215 tilestored(
4216 ptr[reg_ker + reg_dense_stride
4217 + get_full_kernel_offset(ocb, icb, 0, kw)],
4218 Tmm(get_wei_tensor(0, 0)));
4219 }
4220 // restore the zeroing flag (it will be cleared after the end of
4221 // emit_kh_kw_loop, but we may need it until then)
4222 or_(reg_ker, 1);
4223 jmp(kh_loop_end, T_NEAR);
4224
4225 L(skip_ker_zeroing);
4226 add(reg_ker, get_kernel_offset(0, jcp.kw));
4227 jmp(kh_loop_end, T_NEAR);
4228
4229 L(kh_loop_work);
4230
4231 mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
4232 mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
4233
4234 add(reg_src, reg_ihs);
4235 add(reg_ddst, reg_ohs);
4236
4237 Label kw_loop;
4238 xor_(reg_kw, reg_kw);
4239
4240 mov(reg_dense_stride, 64);
4241 L(kw_loop);
4242 {
4243 Label do_zero, ker_init_done;
4244 test(reg_ker, 1);
4245 jnz(do_zero, T_NEAR);
4246
4247 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4248 for (int icb = 0; icb < nb_ic_blocking; icb++)
4249 tileloadd(Tmm(get_wei_tensor(ocb, icb)),
4250 ptr[reg_ker + reg_dense_stride
4251 + get_full_kernel_offset(ocb, icb, 0, 0)]);
4252 jmp(ker_init_done);
4253 L(do_zero);
4254 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4255 for (int icb = 0; icb < nb_ic_blocking; icb++)
4256 tilezero(Tmm(get_wei_tensor(ocb, icb)));
4257
4258 L(ker_init_done);
4259
4260 mov(ptr[rsp + ddst_save_offset], reg_ddst);
4261 mov(ptr[rsp + src_save_offset], reg_src);
4262
4263 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
4264 emit_h_loop();
4265
4266 mov(reg_ddst, ptr[rsp + ddst_save_offset]);
4267 mov(reg_src, ptr[rsp + src_save_offset]);
4268
                // The low bit of the reg_ker ptr is set if the kernel
                // (diff_weights) needs to be zeroed. Those who have
                // byte-aligned their data will suffer the consequences :(
4272 mov(reg_tmp, reg_ker);
4273 and_(reg_ker, ~1);
4274
4275 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4276 for (int icb = 0; icb < nb_ic_blocking; icb++)
4277 tilestored(
4278 ptr[reg_ker + reg_dense_stride
4279 + get_full_kernel_offset(ocb, icb, 0, 0)],
4280 Tmm(get_wei_tensor(ocb, icb)));
4281
4282 mov(reg_ker, reg_tmp);
4283 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
4284 add(reg_kw, 1);
4285 cmp(reg_kw, jcp.kw);
4286 jl(kw_loop);
4287 }
4288
4289 sub(reg_src, reg_ihs);
4290 sub(reg_ddst, reg_ohs);
4291
4292 L(kh_loop_end);
4293 add(reg_kh, 1);
4294 cmp(reg_kh, jcp.kh);
4295 jl(kh_loop);
4296 }
4297 };
4298
4299 mov(reg_src, ptr[param + GET_OFF(src)]);
4300 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4301 mov(reg_ker, ptr[param + GET_OFF(filt)]);
4302 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4303 or_(reg_ker, reg_tmp);
4304
4305 bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
4306
4307 auto src_row_step = get_src_offset(0, 0, 1);
4308 auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
4309 auto ddst_block_step = get_ddst_offset(0, h_block_size);
4310
4311 emit_kh_kw_loop(true, single_kh_kw_loop);
4312
4313 if (!single_kh_kw_loop) {
4314 auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
4315 sub(reg_ker, ker_reset_offset);
4316 and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
4317
4318 add(reg_src, first_src_block_step);
4319 add(reg_ddst, ddst_block_step);
4320
4321 int num_innermost_iters
4322 = (jcp.oh - h_last_block_size) / h_block_size - 1;
4323 if (num_innermost_iters > 0) {
4324 Label h_block_loop;
4325
4326 mov(reg_tmp_w, num_innermost_iters);
4327 kmovw(reg_h_block, reg_tmp_w);
4328 L(h_block_loop);
4329 {
4330 emit_kh_kw_loop(false, false);
4331 sub(reg_ker, ker_reset_offset);
4332 add(reg_src, src_row_step * h_block_size);
4333 add(reg_ddst, ddst_block_step);
4334
4335 kmovw(reg_tmp_w, reg_h_block);
4336 sub(reg_tmp_w, 1);
4337 kmovw(reg_h_block, reg_tmp_w);
4338 jnz(h_block_loop);
4339 }
4340 }
4341
4342 emit_kh_kw_loop(false, true);
4343 }
4344}
4345
4346void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
4347 int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
4348 assert(jcp.ur_w % 2 == 0);
4349 const int str_w = jcp.stride_w;
4350 assert(jcp.tr_iw % str_w == 0);
4351 const int src_stride_w_shift = jcp.tr_iw / str_w;
4352
4353 mov(reg_b_stride, 64);
4354 mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
4355
4356 for (int s = 0; s < str_w; s++) {
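        // For each kernel column of this stride_w phase, load the weight
        // tiles once, accumulate over all ur_w blocks, then store them back.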
4357 for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {
4358
4359 for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4360 for (int icb = 0; icb < nb_ic_blocking; icb++)
4361 tileloadd(Tmm(get_wei_tensor(ocb, icb)),
4362 ptr[reg_kernel + reg_b_stride
4363 + get_full_kernel_offset(
4364 ocb, icb, 0, i_kw)]);
4365
4366 int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
4367 + s * src_stride_w_shift;
4368
4369 for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
4370 dim_t ur_w_src_offset = ur_w_b
4371 * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
4372 dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
4373 for (int icb = 0; icb < nb_ic_blocking; icb++) {
4374 dim_t icb_offset = icb * jcp.tr_src_buf_size;
4375 tileloadd(Tmm(get_src_tensor(icb)),
4376 ptr[reg_src
4377 + jcp.typesize_in
4378 * (src_offset_l + icb_offset)
4379 + ur_w_src_offset + reg_a_stride]);
4380 }
4381 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4382 tileloadd(Tmm(get_ddst_tensor(ocb)),
4383 ptr[reg_ddst
4384 + jcp.typesize_in * ocb
4385 * jcp.tr_diff_dst_buf_size
4386 + ur_w_ddst_offset + reg_b_stride]);
4387 for (int icb = 0; icb < nb_ic_blocking; icb++)
4388 tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
4389 Tmm(get_src_tensor(icb)),
4390 Tmm(get_ddst_tensor(ocb)));
4391 }
4392 }
4393
4394 for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4395 for (int icb = 0; icb < nb_ic_blocking; icb++)
4396 tilestored(ptr[reg_kernel + reg_b_stride
4397 + get_full_kernel_offset(
4398 ocb, icb, 0, i_kw)],
4399 Tmm(get_wei_tensor(ocb, icb)));
4400 }
4401 }
4402 safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
4403 add(reg_kernel, get_kernel_offset(ic_block, 0));
4404}
4405
4406void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(int ocb) {
4407 auto reg_unit_val = reg_tmp.cvt16();
4408 mov(reg_unit_val, 0x3f80); // bf16 value of 1.
4409 vpbroadcastw(vreg_bias_unit, reg_unit_val);
4410
4411 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4412 vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
4413}
4414
4415void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
4416 bool is_partial, int ocb) {
4417 if (!jcp.with_bias) return;
4418 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4419 Label skip_label;
4420 test(reg_tmp, FLAG_IC_FIRST);
4421 jz(skip_label, T_NEAR);
4422
4423 if (is_partial) { compute_diff_bias_init(ocb); }
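    // Reduce diff_dst into the bias accumulator via a bf16 dot-product with a
    // vector of 1.0f: each vdpbf16ps adds a pair of bf16 values per output
    // channel.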
4424 auto compute_step = [&]() {
4425 vmovups(vreg_bias_ddst, ptr[reg_ddst]);
4426 vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
4427 };
4428
4429 Label ow_loop, ow_tail;
4430 int niters = jcp.tr_ow / 2;
4431 if (niters > 0) {
4432 mov(reg_tmp, jcp.tr_ow / 2);
4433 L(ow_loop);
4434 compute_step();
4435 add(reg_ddst, get_ddst_offset(2));
4436 sub(reg_tmp, 1);
4437 jnz(ow_loop, T_NEAR);
4438 }
4439 if (jcp.tr_ow % 2) compute_step();
4440
4441 if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
4442
4443 if (is_partial) {
4444 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4445 vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4446 vreg_bias_acc);
4447 }
4448
4449 L(skip_label);
4450}
4451
4452void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
4453 int nb_oc_blocking) {
    // In the harness_3d_reduction case the diff_bias calculation is done
    // for every output row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
4457 if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
4458 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4459
4460 Label skip_label;
4461 test(reg_tmp, FLAG_IC_FIRST);
4462 jz(skip_label, T_NEAR);
4463
4464 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4465 Label bias_loop, skip_label_local;
4466
4467 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4468 add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);
4469
4470 switch (jcp.harness) {
4471 case harness_2d_reduction:
4472 mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4473 sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4474 break;
4475 case harness_mb_reduction:
4476 case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
4477 case harness_3d_reduction:
4478 default: assert(!"Invalid harness type");
4479 }
4480
4481 cmp(reg_oj, 0);
4482 jle(skip_label_local, T_NEAR); // nothing to do
4483
4484 compute_diff_bias_init(ocb);
4485 L(bias_loop);
4486 {
4487 compute_diff_bias_row(false, ocb);
4488 add(reg_ddst, get_ddst_offset(0, 1));
4489
4490 sub(reg_oj, 1);
4491 jnz(bias_loop, T_NEAR);
4492 }
4493
4494 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4495 vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4496 vreg_bias_acc);
4497
4498 L(skip_label_local);
4499 }
4500 // restore reg_ddst value
4501 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4502
4503 L(skip_label);
4504}
4505
4506void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
4507 int nb_ic_blocking, int nb_oc_blocking) {
4508 Label kh_label, ic_block_label, ow_block_label, kd_label;
4509
4510 int ic_block = jcp.ic_block;
4511 int ic_tail = jcp.ic_tail;
4512
4513 auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
4514 Label ic_tail_label, ic_loop_done_label;
4515
4516 if (ic_tail) {
4517 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
4518 cmp(reg_icb, jcp.ic_tail);
4519 jne(ic_tail_label, T_NEAR);
4520
4521 compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4522 jmp(ic_loop_done_label, T_NEAR);
4523
4524 L(ic_tail_label);
4525 compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
4526 add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
4527 safe_add(reg_src,
4528 get_src_offset(0, 0, filter_h_to_src(1))
4529 - get_src_offset(ic_tail, 0),
4530 reg_long_offt);
4531 L(ic_loop_done_label);
4532 } else {
4533 compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4534 }
4535 };
4536
4537 if (jcp.ndims == 5) {
4538 /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
4539 * 'movs' must be guaranteed. */
4540 mov(ki, reg_kd_count);
4541 mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
4542 mov(aux_reg_src, reg_src);
4543 mov(aux_reg_kernel, reg_kernel);
4544
4545 L(kd_label);
4546 mov(reg_src, aux_reg_src);
4547 mov(reg_kernel, aux_reg_kernel);
4548 }
4549
4550 mov(kj, reg_kh);
4551 L(kh_label);
4552 {
4553 ic_loop(nb_ic_blocking, nb_oc_blocking);
4554
4555 if (jcp.dilate_h > 0) {
4556 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
4557 }
        // subtract the pointer shift made within the ic block loop
4559 // and move to next kh index
4560 add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
4561 dec(kj);
4562 cmp(kj, 0);
4563 jg(kh_label, T_NEAR);
4564 }
4565 if (jcp.ndims == 5) {
4566 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
4567 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
4568 dec(ki);
4569 cmp(ki, 0);
4570 jg(kd_label, T_NEAR);
4571 }
    // In the harness_3d_reduction case the diff_bias calculation is done
    // for every output row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
4575 if (jcp.harness == harness_3d_reduction) {
4576 auto reg_save_ddst = reg_a_stride;
4577 mov(reg_save_ddst, reg_ddst);
4578 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4579 safe_add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
4580 reg_long_offt);
4581 compute_diff_bias_row(true, ocb);
4582 }
4583 mov(reg_ddst, reg_save_ddst);
4584 }
4585
4586 if (jcp.ndims == 5) {
4587 mov(reg_src, aux_reg_src);
4588 mov(reg_kernel, aux_reg_kernel);
4589 mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
4590 od_step_comeback_pointers();
4591 } else {
4592 oh_step_comeback_pointers();
4593 }
4594}
4595
4596void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
4597 int nb_ic_blocking, int nb_oc_blocking) {
4598 if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
4599 Label skip_zeroing, zeroing_loop;
4600
4601 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4602 cmp(reg_tmp, 0);
4603 jz(skip_zeroing, T_NEAR);
4604
4605 Zmm zero = Zmm(0);
4606 vpxord(zero, zero, zero);
4607 if (jcp.with_bias) {
4608 Label skip_bias_zeroing;
4609 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4610 test(reg_tmp, FLAG_IC_FIRST);
4611 jz(skip_bias_zeroing, T_NEAR);
4612 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4613 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4614 vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
4615 }
4616 L(skip_bias_zeroing);
4617 if (jcp.harness == harness_compute_full_spatial)
4618 jmp(skip_zeroing, T_NEAR);
4619 }
4620
4621 mov(reg_b_stride, 64);
4622 tilezero(Tmm(get_wei_tensor(0, 0)));
4623 for (dim_t shift = 0;
4624 shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
4625 shift += get_kernel_offset(jcp.ic_block, 0)) {
4626 for_(int icb = 0; icb < nb_ic_blocking; icb++)
4627 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4628 tilestored(
4629 ptr[reg_kernel + reg_b_stride
4630 + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
4631 Tmm(get_wei_tensor(0, 0)));
4632 }
4633 }
4634 L(skip_zeroing);
4635}
4636
4637void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common(
4638 int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4639 int b_pad = jcp.b_pad;
4640 int t_pad = jcp.t_pad;
4641
4642 bool is_dilated = jcp.dilate_h != 0;
4643 int dilate_h = jcp.dilate_h + 1;
4644 int stride_h = jcp.stride_h;
4645 auto filter_step_size = get_kernel_offset(0, jcp.kw);
4646 auto src_step_size = get_src_offset(0, 0, 1);
4647 auto ddst_step_size = get_ddst_offset(0, 1);
4648 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
4649 oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
4650 oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
4651 oh_dilate_label_end, oh_dilate_setup_label_shift,
4652 oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
4653
4654 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4655 int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
4656 int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
4657 int oh_head_overflow_end = div_up(t_pad, stride_h);
4658 int oh_tail_end = jcp.oh;
4659
4660 int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
4661 int ih_body_end
4662 = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
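    // Output rows split into: a top-padding head [0, oh_head_end), an
    // optional overflow range [oh_head_end, oh_head_overflow_end) where the
    // kernel is taller than the src, the unpadded body up to oh_body_end, and
    // a bottom-padding tail up to oh_tail_end.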
4663
4664 if (is_partial)
4665 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4666 else
4667 xor_(reg_oj, reg_oj);
4668
4669 /* Compute 'top' edge */
4670 if (t_pad > 0) {
4671 if (is_partial) {
4672 cmp(reg_oj, oh_head_overflow_end);
4673 jge(oh_tpad_tail_label_end, T_NEAR);
4674 }
4675 const int overflow
4676 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
4677 const int underflow = div_up(t_pad, dilate_h);
4678 const int initial_kh = jcp.kh - overflow - underflow;
4679
4680 // Setup reg_kh, reg_kernel, and reg_src
4681 mov(reg_kh, initial_kh);
4682 add(reg_kernel, filter_step_size * underflow);
4683 if (is_dilated) {
4684 const int tail = t_pad % dilate_h;
4685 const int shift = tail == 0 ? 0 : dilate_h - tail;
4686 mov(reg_ih_shift, shift);
4687 if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4688 add(reg_src, src_step_size * shift);
4689 }
4690
4691 if (is_partial) {
4692 Label head_setup, head_setup_finish;
4693 cmp(reg_oj, 0);
4694 je(head_setup_finish, T_NEAR);
4695 mov(reg_oj_setup, reg_oj);
4696
4697 L(head_setup);
4698 if (is_dilated) {
4699 inc(reg_ih_shift);
4700 cmp(reg_ih_shift, dilate_h);
4701 jl(oh_dilate_setup_label_shift, T_NEAR);
4702 // unshift src as new kernel element enters
4703 sub(reg_src, src_step_size * (dilate_h - 1));
4704 xor_(reg_ih_shift, reg_ih_shift);
4705 }
4706 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4707 add(reg_kh, stride_h);
4708 sub(reg_kernel, filter_step_size * stride_h);
4709 if (is_dilated) {
4710 jmp(oh_dilate_setup_label_noshift, T_NEAR);
4711 L(oh_dilate_setup_label_shift);
4712 // shift src as old kernel element progresses
4713 add(reg_src, src_step_size * stride_h);
4714 L(oh_dilate_setup_label_noshift);
4715 }
4716 sub(reg_oj_setup, 1);
4717 jg(head_setup, T_NEAR);
4718 L(head_setup_finish);
4719
4720 if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4721 if (oh_head_end < oh_head_overflow_end) {
4722 cmp(reg_oj, oh_head_end);
4723 jge(oh_tpad_label_end, T_NEAR);
4724 }
4725 }
4726
4727        // Setup reg_kernel
4728        // If dilated, shift src ptr
4729        // Loop over the output rows in the top-padding region
4730 L(oh_tpad_label);
4731 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4732 add(reg_ddst, ddst_step_size);
4733 if (is_dilated) {
4734 mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4735 inc(reg_ih_shift);
4736 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4737 cmp(reg_ih_shift, dilate_h);
4738 jl(oh_dilate_label_shift, T_NEAR);
4739 // unshift src as new kernel element enters
4740 sub(reg_src, src_step_size * (dilate_h - 1));
4741 xor_(reg_ih_shift, reg_ih_shift);
4742 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4743 }
4744 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4745 add(reg_kh, stride_h);
4746 sub(reg_kernel, filter_step_size * stride_h);
4747 if (is_dilated) {
4748 jmp(oh_dilate_label_noshift, T_NEAR);
4749 L(oh_dilate_label_shift);
4750 // shift src as old kernel element progresses
4751 add(reg_src, src_step_size * stride_h);
4752 L(oh_dilate_label_noshift);
4753 }
4754 inc(reg_oj);
4755
4756 if (is_partial) {
4757 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4758 jge(oh_bpad_label_end, T_NEAR);
4759 }
4760 cmp(reg_oj, oh_head_end);
4761 jl(oh_tpad_label, T_NEAR);
4762
4763 L(oh_tpad_label_end);
4764        // need a second loop to process the kernel if it is larger than the src
4765        // (does not apply to dilations as they must have unit stride)
4766 if (oh_head_end < oh_head_overflow_end) {
4767 assert(!is_dilated);
4768
4769 cmp(reg_oj, oh_head_overflow_end);
4770 jge(oh_tpad_tail_label_end, T_NEAR);
4771
4772 mov(reg_kh, jcp.ih);
4773 L(oh_tpad_tail_label);
4774 {
4775 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4776 add(reg_ddst, ddst_step_size);
4777 sub(reg_kernel, filter_step_size * stride_h);
4778
4779 inc(reg_oj);
4780
4781 if (is_partial) {
4782 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4783 jge(oh_bpad_label_end, T_NEAR);
4784 }
4785 cmp(reg_oj, oh_head_overflow_end);
4786 jl(oh_tpad_tail_label, T_NEAR);
4787 }
4788 }
4789 if (body_src_start_offset != 0) {
4790 add(reg_kernel, filter_step_size * body_src_start_offset);
4791 add(reg_src, src_step_size * body_src_start_offset);
4792 }
4793 L(oh_tpad_tail_label_end);
4794 }
4795
4796 if (is_partial) {
4797 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4798 jge(oh_bpad_label_end, T_NEAR);
4799 }
4800 cmp(reg_oj, oh_body_end);
4801 jge(oh_label_end, T_NEAR);
4802
4803 /* Compute middle block(s) */
4804 mov(reg_kh, jcp.kh);
4805 L(oh_label);
4806 {
4807 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4808 add(reg_src, src_step_size * stride_h);
4809 add(reg_ddst, ddst_step_size);
4810
4811 inc(reg_oj);
4812
4813 if (is_partial) {
4814 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4815 jge(oh_bpad_label_end, T_NEAR);
4816 }
4817
4818 cmp(reg_oj, oh_body_end);
4819 jl(oh_label, T_NEAR);
4820 }
4821 L(oh_label_end);
4822
4823 /* Compute bottom edge */
4824 if (b_pad > 0) {
4825 if (is_partial) {
4826 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4827 jge(oh_bpad_label_end, T_NEAR);
4828 }
4829 cmp(reg_oj, jcp.oh);
4830 jge(oh_bpad_label_end, T_NEAR);
4831
4832 if (is_dilated) {
4833 // Assumes unit stride for dilations
4834 mov(reg_kh, jcp.kh - 1);
4835 xor_(reg_ih_shift, reg_ih_shift);
4836 } else {
4837 assert(jcp.dilate_h == 0);
4838 mov(reg_kh, jcp.ih - ih_body_end);
4839 }
4840 if (is_partial) {
4841 lea(reg_oj_setup,
4842 ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
4843 if (stride_h == 1 && !is_dilated) {
4844 sub(reg_kh, reg_oj_setup);
4845 } else {
4846 Label body_setup, body_setup_finish, dilate_skip;
4847 cmp(reg_oj_setup, 0);
4848 je(body_setup_finish, T_NEAR);
4849
4850 L(body_setup);
4851 if (is_dilated) {
4852 inc(reg_ih_shift);
4853 cmp(reg_ih_shift, dilate_h);
4854 jl(dilate_skip, T_NEAR);
4855 xor_(reg_ih_shift, reg_ih_shift);
4856 }
4857 sub(reg_kh, stride_h);
4858 L(dilate_skip);
4859 sub(reg_oj_setup, 1);
4860 jg(body_setup, T_NEAR);
4861 L(body_setup_finish);
4862 }
4863 }
4864
4865 if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4866 L(oh_bpad_label);
4867 {
4868 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4869 add(reg_src, src_step_size * stride_h);
4870 add(reg_ddst, ddst_step_size);
4871
4872 if (is_dilated) {
4873 mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4874 inc(reg_ih_shift);
4875 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4876 cmp(reg_ih_shift, dilate_h);
4877 jl(oh_dilate_label_end, T_NEAR);
4878 xor_(reg_ih_shift, reg_ih_shift);
4879 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4880 }
4881 sub(reg_kh, stride_h);
4882 L(oh_dilate_label_end);
4883 inc(reg_oj);
4884 if (is_partial) {
4885 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4886 jge(oh_bpad_label_end, T_NEAR);
4887 }
4888 cmp(reg_oj, oh_tail_end);
4889 jl(oh_bpad_label, T_NEAR);
4890 }
4891 }
4892 L(oh_bpad_label_end);
4893}
4894
4895void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
4896 int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4897 assert(jcp.harness == harness_3d_reduction);
4898
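    // First output-depth index at which the filter starts to overlap the
    // back padding region.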
4899 const int src_backpad_overlap
4900 = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
4901
4902 const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
4903 const auto src_shift = get_src_offset(0, 0, jcp.ih);
4904 const auto ddst_shift = get_ddst_offset(0, jcp.oh);
4905
4906 const int kd_front_pad = nstl::max(0, jcp.f_pad);
4907 const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
4908
4909 Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
4910 backpad_end_label, backpad_label;
4911
4912 /* initially offset 'kd' by f_pad */
4913 mov(reg_src_d, ptr[param + GET_OFF(src)]);
4914 mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
4915
4916 if (is_partial) {
4917 add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
4918 mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
4919 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
4920 } else {
4921 const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
4922 const int kd_offset = get_kernel_offset(
4923 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
4924 add(reg_kernel, kd_offset);
4925 xor_(reg_d_index, reg_d_index);
4926 mov(reg_kd_count, kd_padding);
4927 }
4928
4929 cmp(reg_kd_count, 0);
4930 jle(loop_end_label, T_NEAR); // no iterations along kd
4931 if (is_partial)
4932 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
4933 else
4934 cmp(reg_d_index, jcp.od);
4935 jge(loop_end_label, T_NEAR); // no iterations along depth dimension
4936
4937 L(d_loop_label);
4938
4939 mov(reg_src, reg_src_d);
4940 mov(reg_ddst, reg_ddst_d);
4941
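    // Note: save the depth-loop state on the stack since
    // compute_oh_loop_common() below may clobber these registers.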
4942 mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
4943 mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
4944 mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
4945
4946 compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
4947
4948 mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
4949 mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
4950 mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
4951
4952 /* Compute 'front' edge */
4953 if (jcp.f_pad > 0) {
4954 /* Check if within fpad region */
4955 cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
4956 jge(fpad_end_label, T_NEAR);
4957
4958 /* Fpad steps */
4959 sub(reg_kernel, filter_shift * jcp.stride_d);
4960 add(reg_kd_count, jcp.stride_d);
4961
4962 /* Final number of kernel elements that overlap with src */
4963 const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
4964 cmp(reg_kd_count, src_ker_overlap);
4965 jle(common_block_label, T_NEAR);
4966
4967 /* Correct any excess shifts to kernel and src */
4968 if (jcp.f_pad <= jcp.od * jcp.stride_d) {
4969 /* Filter has moved beyond padding (adjust for stride effects) */
4970 if (jcp.f_pad % jcp.stride_d != 0) {
4971 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
4972 add(reg_kernel, filter_shift * src_corr);
4973 add(reg_src_d, src_shift * src_corr);
4974 }
4975 } else {
4976 /* Filter still overlaps padding (complete reset) */
4977 sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
4978 }
4979
4980 /* Apply correction */
4981 mov(reg_kd_count, src_ker_overlap);
4982 jmp(common_block_label);
4983
4984 L(fpad_end_label);
4985 }
4986
4987    /* Compute 'back' edge */
4988 if (jcp.back_pad > 0) {
4989
4990 /* Check if within back_pad region */
4991 cmp(reg_d_index, src_backpad_overlap - 1);
4992 jl(backpad_end_label, T_NEAR);
4993 jg(backpad_label, T_NEAR);
4994
4995 /* Execute overlap correction between the filter and the initial
4996 * back_pad region. */
4997 mov(reg_kd_count,
4998 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
4999 jmp(backpad_end_label, T_NEAR);
5000
5001 L(backpad_label);
5002 sub(reg_kd_count, jcp.stride_d);
5003 cmp(reg_kd_count, 0);
5004 jle(loop_end_label, T_NEAR);
5005
5006 L(backpad_end_label);
5007 }
5008
5009 /* Compute middle block */
5010 add(reg_src_d, src_shift * jcp.stride_d);
5011
5012 /* Execute common block and loop */
5013 L(common_block_label);
5014 add(reg_ddst_d, ddst_shift);
5015 inc(reg_d_index);
5016 if (is_partial)
5017 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
5018 else
5019 cmp(reg_d_index, jcp.od);
5020 jl(d_loop_label, T_NEAR);
5021
5022 L(loop_end_label);
5023}
5024
5025void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop(
5026 int nb_ic_blocking, int nb_oc_blocking) {
5027 mov(reg_src, ptr[param + GET_OFF(src)]);
5028 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
5029 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
5030
5031 maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking);
5032 maybe_compute_diff_bias(nb_oc_blocking);
5033
5034 switch (jcp.harness) {
5035 case harness_3d_reduction:
5036 compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true);
5037 break;
5038 case harness_2d_reduction:
5039 compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true);
5040 break;
5041 case harness_mb_reduction:
5042 compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
5043 break;
5044 case harness_compute_full_spatial:
5045 compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking);
5046 break;
5047 default: assert(!"Invalid harness type");
5048 }
5049}
5050
5051void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() {
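    // Note: eight-byte spill slots are laid out right after the scratch area
    // used by the ic-block step; stack_space_needed covers all of them.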
5052 kd_count_offset = ic_block_step_stack_size;
5053 src_d_offset = ic_block_step_stack_size + 8;
5054 ddst_d_offset = ic_block_step_stack_size + 16;
5055 d_index_offset = ic_block_step_stack_size + 24;
5056 ih_dilate_offset = ic_block_step_stack_size + 32;
5057 src_save_offset = ic_block_step_stack_size + 40;
5058 ddst_save_offset = ic_block_step_stack_size + 48;
5059 stack_space_needed = ic_block_step_stack_size + 56;
5060}
5061
5062void jit_avx512_core_amx_bwd_weights_kernel_t::generate() {
5063 preamble();
5064
5065 setup_stack_space();
5066
5067 sub(rsp, stack_space_needed);
5068
5069 Label last_ic_block_label, last_blocks_done_label;
5070
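    // Dispatch at runtime into one of four compute_loop() variants depending
    // on whether the current ic and oc blocks are full or tail blocks.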
5071 mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]);
5072 cmp(reg_tmp, 0);
5073 jne(last_ic_block_label, T_NEAR);
5074 { // full nb_ic_blocking
5075 Label last_oc_block_label;
5076 mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
5077 cmp(reg_tmp, 0);
5078 jne(last_oc_block_label, T_NEAR);
5079 { // full nb_oc_blocking
5080 compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking);
5081 jmp(last_blocks_done_label, T_NEAR);
5082 }
5083 L(last_oc_block_label);
5084 { // tail of nb_oc_blocking
5085 compute_loop(jcp.nb_ic_blocking, 1);
5086 jmp(last_blocks_done_label, T_NEAR);
5087 }
5088 }
5089 L(last_ic_block_label);
5090 { // tail nb_ic_blocking
5091 Label last_oc_block_label;
5092 mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
5093 cmp(reg_tmp, 0);
5094 jne(last_oc_block_label, T_NEAR);
5095 { // full nb_oc_blocking
5096 compute_loop(1, jcp.nb_oc_blocking);
5097 jmp(last_blocks_done_label, T_NEAR);
5098 }
5099 L(last_oc_block_label);
5100 { // tail of nb_oc_blocking
5101 compute_loop(1, 1);
5102 jmp(last_blocks_done_label, T_NEAR);
5103 }
5104 }
5105
5106 L(last_blocks_done_label);
5107 add(rsp, stack_space_needed);
5108
5109 postamble();
5110}
5111
5112status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
5113 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
5114 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
5115 memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
5116 const memory_desc_wrapper src_d(&src_md);
5117 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
5118 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
5119 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
5120
5121 jcp = zero<decltype(jcp)>();
5122
5123 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
5124 int ndims = src_d.ndims();
5125
5126 if (!mayiuse(avx512_core_amx)) return status::unimplemented;
5127 jcp.isa = avx512_core_amx;
5128
5129 jcp.has_vnni = true; // Needed for transpose routines
5130 jcp.nthr = nthreads;
5131
5132 jcp.ndims = ndims;
5133 jcp.prop_kind = cd.prop_kind;
5134
5135 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
5136 jcp.mb = src_d.dims()[0];
5137
5138 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
5139 jcp.oc_without_padding = jcp.oc;
5140 jcp.ic = src_d.dims()[1] / jcp.ngroups;
5141
5142 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
5143 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
5144 jcp.iw = src_d.dims()[ndims - 1];
5145 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
5146 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
5147 jcp.ow = diff_dst_d.dims()[ndims - 1];
5148
5149 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
5150 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
5151 jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
5152
5153 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
5154 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
5155 jcp.l_pad = cd.padding[0][ndims - 3];
5156
5157 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
5158 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
5159 jcp.stride_w = cd.strides[ndims - 3];
5160
5161 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
5162 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
5163 jcp.dilate_w = cd.dilates[ndims - 3];
5164
5165 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
5166 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
5167 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
5168
5169 bool ok = true
5170 // general condition to simplify dilations
5171 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
5172 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
5173 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
5174 // special condition to simplify dilations in compute_oh_loop_common
5175 && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
5176 if (!ok) return status::unimplemented;
5177
5178 ok = true && one_of(ndims, 3, 4, 5)
5179 && everyone_is(
5180 data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
5181 && one_of(diff_weights_d.data_type(), data_type::f32,
5182 data_type::bf16);
5183 if (!ok) return status::unimplemented;
5184
5185 jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;
5186
5187 jcp.r_pad = nstl::max(0,
5188 calculate_end_padding(
5189 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
5190 jcp.b_pad = nstl::max(0,
5191 calculate_end_padding(
5192 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
5193 jcp.back_pad = nstl::max(0,
5194 calculate_end_padding(
5195 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
5196
5197 /* XXX: no support for padding when dilation_d > 0 */
5198 if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
5199 return status::unimplemented;
5200
5201 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
5202 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
5203 jcp.ohp = jcp.oh;
5204 jcp.owp = jcp.ow;
5205
5206 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
5207 if (jcp.is_depthwise)
5208 return status::unimplemented; // TODO: add support of DW convolution
5209
5210 const int dat_format_tag = ndims - 3;
5211 format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
5212 format_tag::nhwc, format_tag::ndhwc);
5213 format_tag_t dat_tag_opt = dat_tag_nspc;
5214
5215 if (src_d.format_kind() == format_kind::any) {
5216 CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
5217 jcp.src_tag = dat_tag_opt;
5218 } else
5219 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
5220 if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
5221 jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
5222
5223 if (diff_dst_d.format_kind() == format_kind::any) {
5224 CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
5225 jcp.dst_tag = jcp.src_tag;
5226 } else
5227 jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
5228 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
5229
5230 if (!jcp.is_nspc) return status::unimplemented;
5231
5232 const int wei_format_tag = 2 * ndims - 6 + with_groups;
5233 format_tag_t wei_tag;
5234 if (jcp.transform_to_vnni)
5235 wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
5236 format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
5237 format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
5238 format_tag::gOIdhw16i16o2i);
5239 else
5240 wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
5241 format_tag::gOIw16i16o, format_tag::OIhw16i16o,
5242 format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
5243 format_tag::gOIdhw16i16o);
5244 if (diff_weights_md.format_kind == format_kind::any) {
5245 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
5246 jcp.wei_tag = wei_tag;
5247 } else {
5248 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
5249 if (jcp.wei_tag != wei_tag) return status::unimplemented;
5250 }
5251 jcp.wei_dt = diff_weights_d.data_type();
5252
5253 /* conditions on bias memory */
5254 jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
5255 if (jcp.with_bias) {
5256 if (diff_bias_d.format_kind() == format_kind::any)
5257 CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
5258 }
5259 jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
5260 jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
5261
5262 /* kernel applicability check wrt boundaries
5263 * the conditions are quite general across the kernels we have,
5264 * but ideally the check should belong to a specific kernel... */
5265 const int max_pad_h = ext_kh / 2;
5266 const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
5267 && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
5268 && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
5269 if (!boundaries_ok) return status::unimplemented;
5270
5271 jcp.ic_block = 16;
5272 jcp.oc_block = 16;
5273
5274 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
5275 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
5276
5277 jcp.ic_tail = jcp.ic % jcp.ic_block;
5278 jcp.oc_tail = jcp.oc % jcp.oc_block;
5279
5280 jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
5281 jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;
5282
5283 const int target_palette = amx::get_target_palette();
5284 jcp.max_tiles = amx::get_max_tiles(target_palette);
5285 jcp.full_tile_width = amx::get_max_rows(target_palette);
5286
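    // This kernel assumes AMX palette 1: 8 tile registers with 16 rows each.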
5287 if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
5288 return status::unimplemented;
5289
5290 const bool is_2d = (ndims == 4);
5291 const bool is_3d = (ndims == 5);
5292 jcp.typesize_in = sizeof(bfloat16_t);
5293 jcp.typesize_out = sizeof(float);
5294
5295 // TODO: Find more shapes (especially 3D with large spatials) for which
5296 // local transposition will be beneficial. Furthermore, for TBB threads
5297 // more shapes can potentially benefit from spatial blocking
5298 int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
5299
5300 jcp.global_transpose = dnnl_thr_syncable();
5301 jcp.spatial_blk_size = optimal_blk_size;
5302
5303 const int tr_round = 32; // To load full tile register
5304 int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
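    // tr_iw is the row width of the transposed source buffer: the strided
    // input width plus transpose padding, rounded up to tr_round so rows map
    // to whole tile loads, then multiplied back by stride_w.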
5305 jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
5306 * jcp.stride_w;
5307
5308 jcp.tr_src_num_guard_elems = tr_pad; // upper bound
5309 jcp.tr_ow = rnd_up(jcp.ow, 2);
5310
5311 if (jcp.tr_ow <= max_ur_w) {
5312 jcp.ur_w = jcp.tr_ow;
5313 jcp.ur_w_blocks = 1;
5314 } else {
5315 jcp.ur_w = 1;
5316 for (int i = max_ur_w; i >= 1; i -= 2) {
5317 if (jcp.tr_ow % i == 0) {
5318 jcp.ur_w = i;
5319 break;
5320 }
5321 }
5322 jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
5323 }
5324
5325 bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
5326 && jcp.oc <= diff_dst_d.padded_dims()[1]
5327 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
5328 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
5329 if (!args_ok) return status::unimplemented;
5330
5331 bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
5332 && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
5333 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
5334 // TODO: Remove this constraint: only 3x3 kernel works now
5335 && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
5336 && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
5337 && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;
5338
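    // Pick the threading harness: 3D shapes reduce over od, 2D shapes over
    // oh, 1D shapes over the minibatch only; shapes satisfying the
    // full-spatial constraints above use the dedicated full-spatial loop.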
5339 jcp.harness = ndims == 5
5340 ? harness_3d_reduction
5341 : (use_full_spat_loop ? harness_compute_full_spatial
5342 : (ndims == 4) ? harness_2d_reduction
5343 : harness_mb_reduction);
5344 switch (jcp.harness) {
5345 case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
5346 case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
5347 case harness_compute_full_spatial:
5348 case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
5349 default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
5350 }
5351 { // balancing
5352 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
5353 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
5354 jcp.nthr = nthr;
5355 jcp.nthr_mb = nthr_mb;
5356 jcp.nthr_g = nthr_g;
5357 jcp.nthr_oc_b = nthr_oc_b;
5358 jcp.nthr_ic_b = nthr_ic_b;
5359
5360 // TODO: Optimize memory allocation when threaded on height and depth
5361 jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
5362 jcp.tr_src_buf_count = jcp.global_transpose
5363 ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
5364 : jcp.nthr;
5365
5366 jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
5367 jcp.tr_diff_dst_buf_count = jcp.global_transpose
5368 ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
5369 : jcp.nthr;
5370 }
5371
5372 return status::success;
5373}
5374
5375status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
5376 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
5377 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
5378 memory_desc_t &diff_dst_md) {
5379 const memory_desc_wrapper src_d(&src_md);
5380 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
5381 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
5382
5383 // XXX: See the comment about tr_iw and guarding elements in
5384 // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
5385 const size_t tr_src_size
5386 = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
5387 + jcp.tr_src_num_guard_elems;
5388 scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
5389
5390 /* prepare synchronization contexts */
5391 if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
5392 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
5393 scratchpad.book<simple_barrier::ctx_t>(
5394 key_conv_tr_src_bctx, tr_src_bctx_size);
5395 }
5396
5397 const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
5398 * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;
5399
5400 const size_t min_align = 64;
5401 scratchpad.book(
5402 key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);
5403
5404 /* prepare synchronization contexts */
5405 if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
5406 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
5407 scratchpad.book<simple_barrier::ctx_t>(
5408 key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
5409 }
5410
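    // Reduction scratch is needed when reducing over several minibatch
    // threads, or when the final weights/bias are bf16 and partial sums must
    // be accumulated in f32 first.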
5411 if (IMPLICATION(jcp.nthr_mb == 1,
5412 (jcp.with_bias && jcp.bia_dt == data_type::bf16)
5413 || jcp.wei_dt == data_type::bf16)) {
5414 const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
5415 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
5416 const size_t bia_size
5417 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
5418
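        // Note: when the destination weights (or bias) are f32, the user
        // buffer can hold one of the partial reductions itself, hence one
        // fewer scratch buffer than in the bf16 case.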
5419 const int num_wei_buffers
5420 = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
5421 const int num_bia_buffers = jcp.with_bias
5422 ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
5423 : jcp.nthr_mb - 1)
5424 : 0;
5425
5426 const size_t wei_bia_reduction_size
5427 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
5428
5429 scratchpad.book<float>(
5430 key_conv_wei_bia_reduction, wei_bia_reduction_size);
5431
5432 scratchpad.book<simple_barrier::ctx_t>(
5433 key_conv_wei_bia_reduction_bctx, 1);
5434 }
5435
5436 if (jcp.with_bias
5437 && ((jcp.oc_without_padding % jcp.oc_block != 0)
5438 && jcp.bia_dt == data_type::f32)) {
5439 scratchpad.book(key_conv_padded_bias,
5440 jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
5441 }
5442 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
5443
5444 constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
5445            << 30; // 32 GB - TODO: maybe this is too large?
5446 const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
5447 * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
5448 const size_t scratchpad_limit
5449 = nstl::min(scratchpad_limit_by_absolute_value,
5450 scratchpad_limit_by_tensor_sizes);
5451 if (scratchpad.size() > scratchpad_limit)
5452 return status::unimplemented;
5453 else
5454 return status::success;
5455}
5456
5457void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
5458 int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
5459 int &nthr_ic_b_) {
5460 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
5461
5462 const int max_threads = dnnl_get_max_threads();
5463
5464 if (max_threads < j.ngroups) {
5465 /* simplification... fortunately it doesn't hurt much */
5466 nthr_ = nthr_g_ = max_threads;
5467 return;
5468 }
5469
5470 nthr_g_ = j.ngroups;
5471 const int nthr = max_threads / nthr_g_;
5472
5473 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
5474        /* Calculate per-thread memory cost (read/write). The high-level
5475         * optimizer tries to minimize memory consumption. A few notes:
5476         * (n1) if the weights tensor size is less than the source and
5477         *      destination tensors, we apply the ratio of the source and
5478         *      destination tensor sizes to the weights one as a compensation
5479         *      coefficient to avoid parallelization across batch size only;
5480         *      otherwise we apply an additional coefficient to the source
5481         *      component based on performance measurements
5482         * (n2) use scales based on the output vs input channels ratio for the
5483         *      source and destination components to improve threading balance
5484         *      across input and output channels */
5485
5486 const dim_t src_type_size = 2;
5487 const dim_t wei_type_size = 4;
5488
5489 dim_t src_size
5490 = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
5491 dim_t dst_size
5492 = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
5493 dim_t wei_size
5494 = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;
5495
5496 float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
5497 float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
5498 / (j.nb_ic / j.nb_ic_blocking);
5499 auto get_src_coef = [=]() {
5500 float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
5501 if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
5502
5503 return src_coef;
5504 };
5505
5506 auto get_dst_coef
5507 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
5508
5509 auto get_wei_coef
5510 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
5511
5512 const float src_coef = get_src_coef();
5513 const float dst_coef = get_dst_coef();
5514 const float wei_coef = get_wei_coef();
5515
5516 float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
5517 * div_up(j.ngroups, nthr_g_)
5518 * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
5519 * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
5520 / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
5521 float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
5522 * div_up((j.nb_oc / j.nb_oc_blocking),
5523 (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
5524 * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
5525 * j.kd * (j.ic_block * j.nb_ic_blocking)
5526 * (j.oc_block * j.nb_oc_blocking);
5527 float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
5528 * div_up(j.ngroups, nthr_g_)
5529 * div_up((j.nb_oc / j.nb_oc_blocking),
5530 (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
5531 * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
5532 / j.nthr_mb_work;
5533
5534 return src_v + dst_v + wei_v;
5535 };
5536
5537 float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
5538
5539 /* find the best thread distribution with lowest memory cost */
5540
5541 const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
5542 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
5543 const int nthr_par = nthr / nthr_mb;
5544 const int nthr_oc_b_max = nstl::min(nthr_par,
5545                (j.nb_oc / j.nb_oc_blocking)); // number of nb_oc blocks
5546 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
5547 int nthr_ic_b = nstl::min(
5548 nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));
5549
5550 float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
5551 if (mem_cost <= best_mem_cost) {
5552 best_mem_cost = mem_cost;
5553 nthr_mb_ = nthr_mb;
5554 nthr_oc_b_ = nthr_oc_b;
5555 nthr_ic_b_ = nthr_ic_b;
5556 }
5557 }
5558 }
5559
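    // If minibatch parallelism already takes more than half of the threads,
    // prefer giving it all of them.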
5560 if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
5561 nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
5562 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
5563
5564 assert(nthr_ <= max_threads);
5565}
5566
5567// start of diff_bias kernel
5568void jit_avx512_core_amx_bwd_bias_kernel_t::compute_diff_bias_row(int ocb) {
5569
5570 auto compute_step = [&]() {
5571 if (jcp.ddst_dt == data_type::bf16) {
5572 vmovups(vreg_bias_ddst, ptr[reg_ddst]);
5573 vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
5574 } else if (jcp.ddst_dt == data_type::f16) {
5575            // The ddst data is in vnni format (S16c2s), which needs to be
5576            // reduced along the S dimension. Since there is no f16_vnni
5577            // instruction, we emulate the reduction.
5578 // ddst_data: [a,a, b,b ... p,p] in f16
5579 // req_output = [a+a, b+b, ... p+p] in f32 i.e., [A, B, ... P]
5580
5581 // [d,d, c,c, b,b, a,a] in f32 from now on
5582 vcvtph2psx(yreg_bias_ddst0, ptr[reg_ddst]);
5583 // [h,h, g,g, f,f, e,e] in f32 from now on
5584 vcvtph2psx(yreg_bias_ddst1, ptr[reg_ddst + 16]);
5585 // [h+h, g+g, d+d, c+c, f+f, e+e, b+b, a+a] i.e., [H, G, D, C, F, E, B, A]
5586 vhaddps(yreg_bias_ddst0, yreg_bias_ddst0, yreg_bias_ddst1);
5587 // accumulate with previous data
5588 vaddps(yreg_bias_acc0, yreg_bias_acc0, yreg_bias_ddst0);
5589
5590 vcvtph2psx(yreg_bias_ddst0, ptr[reg_ddst + 32]);
5591 vcvtph2psx(yreg_bias_ddst1, ptr[reg_ddst + 48]);
5592 // [p+p, o+o, l+l, k+k, n+n, m+m, j+j, i+i] i.e., [P, O, L, K, N, M, J, I]
5593 vhaddps(yreg_bias_ddst0, yreg_bias_ddst0, yreg_bias_ddst1);
5594 vaddps(yreg_bias_acc1, yreg_bias_acc1, yreg_bias_ddst0);
5595 }
5596 };
5597
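    // Each step consumes two adjacent ow points (one vnni pair) for the full
    // oc block; loop tr_ow / 2 times, handle an odd remainder separately, and
    // rewind the pointer afterwards.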
5598 Label ow_loop;
5599 int niters = jcp.tr_ow / 2;
5600 if (niters > 0) {
5601 mov(reg_tmp, jcp.tr_ow / 2);
5602 L(ow_loop);
5603 compute_step();
5604 add(reg_ddst, get_ddst_offset(2));
5605 sub(reg_tmp, 1);
5606 jnz(ow_loop, T_NEAR);
5607 }
5608 if (jcp.tr_ow % 2) compute_step();
5609
5610 if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
5611}
5612
5613void jit_avx512_core_amx_bwd_bias_kernel_t::compute_diff_bias(
5614 int nb_oc_blocking) {
5615
5616 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
5617 Label bias_loop;
5618
5619 // pointer to diff_dst
5620 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
5621 add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);
5622
5623 // number of rows
5624 mov(reg_oj, reg_nrows);
5625
5626 // accumulator initialization
5627 if (jcp.ddst_dt == data_type::f16) {
5628 vpxord(yreg_bias_acc0, yreg_bias_acc0, yreg_bias_acc0);
5629 vpxord(yreg_bias_acc1, yreg_bias_acc1, yreg_bias_acc1);
5630 } else {
5631 vpxord(vreg_bias_acc, vreg_bias_acc, vreg_bias_acc);
5632 }
5633 cmp(reg_initial, 0);
5634 jnz(bias_loop, T_NEAR);
5635 const size_t offset = sizeof(float) * ocb * jcp.oc_block;
5636 if (jcp.ddst_dt == data_type::f16) {
5637            // The data is in plain format; transform it while loading.
5638 // i.e.,[H, G, F, E, D, C, B, A] -> [H, G, D, C, F, E, B, A]
5639 // and [P, O, N, M, L, K, J, I] -> [P, O, L, K, N, M, J, I]
5640 vpermq(yreg_bias_acc0, ptr[reg_bias + offset], 0xd8);
5641 vpermq(yreg_bias_acc1,
5642 ptr[reg_bias + offset + vreg_traits<Ymm>::vlen], 0xd8);
5643 } else {
5644 vmovups(vreg_bias_acc, ptr[reg_bias + offset]);
5645 }
5646 // loop by rows
5647 L(bias_loop);
5648 {
5649 compute_diff_bias_row(ocb);
5650 add(reg_ddst, get_ddst_offset(0, 1));
5651
5652 sub(reg_oj, 1);
5653 jnz(bias_loop, T_NEAR);
5654 }
5655
5656 // store accumulator
5657 if (jcp.ddst_dt == data_type::bf16) {
5658 vmovups(ptr[reg_bias + offset], vreg_bias_acc);
5659 } else if (jcp.ddst_dt == data_type::f16) {
5660 // transform to plain before storing.
5661 // i.e., [H, G, D, C, F, E, B, A] -> [H, G, F, E, D, C, B, A]
5662 // and [P, O, L, K, N, M, J, I] -> [P, O, N, M, L, K, J, I]
5663 vpermq(yreg_bias_acc0, yreg_bias_acc0, 0xd8);
5664 vpermq(yreg_bias_acc1, yreg_bias_acc1, 0xd8);
5665 vmovups(ptr[reg_bias + offset], yreg_bias_acc0);
5666 vmovups(ptr[reg_bias + offset + vreg_traits<Ymm>::vlen],
5667 yreg_bias_acc1);
5668 }
5669 }
5670}
5671
5672void jit_avx512_core_amx_bwd_bias_kernel_t::generate() {
5673 preamble();
5674
5675 Label end_label;
5676 // number of rows
5677 mov(reg_nrows, ptr[param + GET_OFF(os_index_end)]);
5678 sub(reg_nrows, ptr[param + GET_OFF(os_index_begin)]);
5679 cmp(reg_nrows, 0);
5680 jle(end_label, T_NEAR); // nothing to do
5681
5682 if (jcp.ddst_dt == data_type::bf16) {
5683 auto reg_unit_val = reg_tmp.cvt16();
5684 mov(reg_unit_val, 0x3f80); // bf16 value of 1.
5685 vpbroadcastw(vreg_bias_unit, reg_unit_val);
5686 }
5687 mov(reg_bias, ptr[param + GET_OFF(bias)]);
5688 mov(reg_initial, ptr[param + GET_OFF(channel)]);
5689
5690 Label last_oc_block_label;
5691 mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
5692 cmp(reg_tmp, 0);
5693 jne(last_oc_block_label, T_NEAR);
5694 { // full nb_oc_blocking
5695 compute_diff_bias(jcp.nb_oc_blocking);
5696 jmp(end_label, T_NEAR);
5697 }
5698 L(last_oc_block_label);
5699 { // tail of nb_oc_blocking
5700 compute_diff_bias(1);
5701 jmp(end_label, T_NEAR);
5702 }
5703
5704 L(end_label);
5705
5706 postamble();
5707}
5708// end of diff_bias kernel
5709
5710} // namespace x64
5711} // namespace cpu
5712} // namespace impl
5713} // namespace dnnl
5714
5715// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
5716