1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/memory_tracking.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | #include "cpu/scale_utils.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
29 | #include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp" |
30 | |
31 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | |
38 | using namespace dnnl::impl::memory_tracking::names; |
39 | using namespace dnnl::impl::data_type; |
40 | using namespace dnnl::impl::utils; |
41 | using namespace Xbyak; |
42 | |
43 | void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) { |
44 | for (int oc = 0; oc < jcp.nb_oc_blocking; oc++) |
45 | for (int ur = 0; ur < ur_w; ur++) { |
46 | const Zmm zmm = zmm_out(ur, oc); |
47 | vpxord(zmm, zmm, zmm); |
48 | } |
49 | } |
50 | |
// Scale the accumulated weight sums by the runtime source zero-point and
// write them out to the zero-point padding buffer (reg_zp_pbuff).
// 'last_oc_block_flag' enables tail masking for a partially filled final
// oc block. Only the nspc (channels-last) layout is supported.
void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
        int ur_w, bool last_oc_block_flag) {
    assert(jcp.is_nspc);

    const int nb_oc_block = jcp.nb_oc_blocking;
    const int oc_block = jcp.oc_block;

    // EVEX embedded-broadcast load of the scalar src zero-point
    const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);

    /* write out register to output_addr */
    for (int oc = 0; oc < nb_oc_block; oc++) {
        // tail mask applies only to the very last oc block
        const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
        for (int ur = 0; ur < ur_w; ur++) {
            // int32 element offset within the nspc-laid-out buffer
            const int output_offset = sizeof(int32_t)
                    * (oc * oc_block
                            + ur * jcp.oc_without_padding * jcp.ngroups);
            const Zmm zmm_dst = zmm_out(ur, oc);
            const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
            // multiply dst by src_zero_point
            vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
            vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
        }
    }
}
75 | |
// Accumulate, for every output position of the current ur_w window that
// falls into a padded region, the sum of the int8 weights along kw and the
// ic sub-blocks: accum += conv(1, wei_s8). The src zero-point multiply is
// applied later in store_output(). 'padded' forces all positions to be
// treated as padding (used for h/d overflow rows).
void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool padded) {

    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block_int_np;
    const int oc_block = jcp.oc_block;
    const int nb_oc_block = jcp.nb_oc_blocking;

    const bool ic_tail
            = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
    const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;

    /* Skip the last loads of input
        if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
    const int icb = (last_ic_block_flag == last_ic_block)
            ? div_up(
                    (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
            : ic_block / ic_inner_block;

    // Byte offset of one (ic sub-block x oc_block) weights chunk for kernel
    // position 'ki' of output-channel block 'ocb'. In reduced-lowering mode
    // ('is_relo') the kh dimension is folded into the kw stride.
    auto get_filter_offset = [=](int ocb, int ic, int ki) {
        size_t w_step = jcp.is_relo ? jcp.kh : 1;
        size_t kw_offset = static_cast<size_t>(ki) * w_step
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t offset = kw_offset
                + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
                + static_cast<size_t>(ic) * oc_block * ic_inner_block;
        return sizeof(char) * offset;
    };
    // Emit one dot-product accumulation: accum += dpbusd(1, weights).
    // In relo mode the weights are first permuted from '..i16o' into the
    // VNNI layout via vpermi2b, zero-masking the ic tail.
    auto compute_fma = [=](const Zmm zmm_accum, const int ic,
                               const Address addr) {
        if (jcp.is_relo) {
            vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
            const Zmm r_zmm = masked_write && ic == icb - 1
                    ? zmm_permb | kmask_ic_block | T_z
                    : zmm_permb;
            // only values from 'src2' are used to write dst
            vpermi2b(r_zmm, zmm_permb, addr);
            vpdpbusd(zmm_accum, zmm_one,
                    zmm_permb); // XXX - using the same register for all ur_w
        } else {
            vpdpbusd(zmm_accum, zmm_one, addr);
        }
    };

    // relo + ic tail: load the zero-mask for the trailing ic elements
    if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
        const Reg64 reg_tmp = reg_scratch;
        mov(reg_tmp, ic_mask_label);
        kmovq(kmask_ic_block, qword[reg_tmp]);
    }
    if (jcp.is_relo) mov(reg_scratch, permb_idx_label);

    for (int ki = 0; ki < kw; ki++) {
        // [ur_start, ur_end) is the range of output positions for which
        // kernel position 'ki' reads real (non-padded) input
        const int ur_start = get_ow_start(ki, pad_l);
        const int ur_end = get_ow_end(ur_w, ki, pad_r);
        for (int ur = 0; ur < ur_w; ur++) {
            // Calculate zero_point padding as:
            // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0)
            if (ur < ur_start || ur >= ur_end || padded) {
                for (int oc = 0; oc < nb_oc_block; oc++) {
                    const Zmm zmm_accum = zmm_out(ur, oc);
                    for (int ic = 0; ic < icb; ic++) {
                        const auto addr_filt = EVEX_compress_addr(
                                aux_reg_filt, get_filter_offset(oc, ic, ki));
                        compute_fma(zmm_accum, ic, addr_filt);
                    }
                }
            }
        }
    }
}
148 | |
149 | void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l, |
150 | int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) { |
151 | |
152 | Label kh_label, skip_kh_loop; |
153 | const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw; |
154 | const size_t shift_wei_h_step = sizeof(char) |
155 | * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np |
156 | * jcp.oc_block; |
157 | |
158 | // Compute zero_point compensation for the padded region. Total compute |
159 | // area is 'overflow * kw' where 'overflow' indicates the overlap |
160 | // between the filter and either top_pad or bottom_pad region. |
161 | auto compute_kh_loop = [=](size_t param_overflow) { |
162 | Label overflow_label, no_overflow_label; |
163 | |
164 | mov(reg_overflow, ptr[param1 + param_overflow]); |
165 | cmp(reg_overflow, 0); |
166 | je(no_overflow_label, T_NEAR); |
167 | L(overflow_label); |
168 | { |
169 | compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); |
170 | add(aux_reg_filt, shift_wei_h_step); |
171 | dec(reg_overflow); |
172 | jne(overflow_label, T_NEAR); |
173 | } |
174 | L(no_overflow_label); |
175 | }; |
176 | |
177 | if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow)); |
178 | |
179 | // check for holes and skip computation due to dilation |
180 | mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); |
181 | if ((jcp.dilate_h >= jcp.ih)) { |
182 | cmp(reg_kj, 0); |
183 | je(skip_kh_loop, T_NEAR); |
184 | } |
185 | |
186 | L(kh_label); |
187 | { |
188 | compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); |
189 | |
190 | add(aux_reg_filt, shift_wei_h_step); |
191 | dec(reg_kj); |
192 | jne(kh_label, T_NEAR); |
193 | } |
194 | |
195 | L(skip_kh_loop); |
196 | |
197 | if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow)); |
198 | } |
199 | |
200 | void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l, |
201 | int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) { |
202 | |
203 | Label kd_label, skip_kd_loop; |
204 | const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw; |
205 | const size_t shift_wei_h_step = sizeof(char) |
206 | * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np |
207 | * jcp.oc_block; |
208 | |
209 | // Compute zero_point compensation for the padded region. Total compute |
210 | // area is 'overflow * kh * kw' where 'overflow' indicates the overlap |
211 | // between the filter and either front_pad or back_pad region. |
212 | auto compute_kd_loop = [=](size_t param_overflow) { |
213 | Label kh_loop_label; |
214 | Label no_overflow_label, overflow_label; |
215 | |
216 | mov(reg_ki, ptr[param1 + param_overflow]); |
217 | cmp(reg_ki, 0); |
218 | je(no_overflow_label, T_NEAR); |
219 | L(overflow_label); |
220 | { |
221 | mov(aux_reg_filt, aux_reg_filt_d); |
222 | mov(reg_kj, jcp.kh); |
223 | L(kh_loop_label); |
224 | { |
225 | compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); |
226 | add(aux_reg_filt, shift_wei_h_step); |
227 | dec(reg_kj); |
228 | jne(kh_loop_label, T_NEAR); |
229 | } |
230 | add(aux_reg_filt_d, shift_wei_h_step * jcp.kh); |
231 | dec(reg_ki); |
232 | jne(overflow_label, T_NEAR); |
233 | } |
234 | L(no_overflow_label); |
235 | }; |
236 | |
237 | const bool zp_d_padding |
238 | = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0); |
239 | if (zp_d_padding) { |
240 | mov(aux_reg_filt_d, reg_filt); |
241 | compute_kd_loop(GET_OFF(f_overflow)); |
242 | |
243 | // check for holes and skip computation due to dilation |
244 | mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); |
245 | if (jcp.dilate_d >= jcp.id) { |
246 | cmp(reg_ki, 0); |
247 | je(skip_kd_loop, T_NEAR); |
248 | } |
249 | L(kd_label); |
250 | mov(aux_reg_filt, aux_reg_filt_d); |
251 | |
252 | } else { |
253 | mov(aux_reg_filt, reg_filt); |
254 | } |
255 | |
256 | kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad); |
257 | |
258 | if (zp_d_padding) { |
259 | add(aux_reg_filt_d, shift_wei_h_step * jcp.kh); |
260 | dec(reg_ki); |
261 | jne(kd_label, T_NEAR); |
262 | |
263 | L(skip_kd_loop); |
264 | |
265 | compute_kd_loop(GET_OFF(back_overflow)); |
266 | } |
267 | } |
268 | |
// Runtime loop over the input-channel blocks: zero the accumulators,
// accumulate compensation across all ic blocks (with tail handling on the
// last one), then store the scaled result via store_output() with oc-tail
// handling for the last oc block.
void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
        int ur_w, int pad_l, int pad_r, bool handle_h_pad) {

    Label icb_label;
    const size_t nb_ic = jcp.nb_ic_int;
    const bool do_icb_loop = nb_ic > 1;

    /* Initialize zmm_one for weight accumulation */
    xor_(reg_scratch, reg_scratch);
    const Reg8 _t8 = reg_scratch.cvt8();
    mov(_t8, 0x1);
    vpbroadcastb(zmm_one, _t8);

    prepare_output(ur_w);

    mov(reg_icb, nb_ic);

    L(icb_label);
    if (jcp.ic_without_padding != jcp.ic) {
        // ic has a tail: dispatch at runtime between the masked (last ic
        // block) and the common kernel
        Label common_ker, end_ker;
        if (do_icb_loop) {
            cmp(reg_icb, 1); // The last ic block
            jne(common_ker, T_NEAR);
        }
        kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
        if (do_icb_loop) {
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);

            L(end_ker);
        }
    } else {
        kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
    }
    // End of IC Loop
    if (do_icb_loop) {
        const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
                * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
        add(reg_filt, sizeof(char) * shift_wei_icb_step);

        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_label, T_NEAR);

        // restore reg_filt to the first ic block for the next width step
        sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        // oc has a tail: runtime check whether this is the last oc block
        Label common_store, end_store;

        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}
335 | |
// Statically unroll the output-width dimension into three regions:
// left-padding positions, an optional single middle element (needed only
// when there is height/depth padding, 'h_padding'), and right-padding
// positions. Each region invokes icb_loop() and advances reg_zp_pbuff.
void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
        const bool h_padding) {

    // byte shift of the zp buffer pointer for 'ur_w' output positions
    auto ur_w_shift = [&](const int ur_w) {
        return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
    };

    // registers limit the unroll factor per icb_loop() call
    const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
            / (jcp.nb_oc_blocking);
    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int l_pad = jcp.l_pad;

    const int l_pad_output = jcp.l_pad_output;
    const int r_pad_output = jcp.r_pad_output;

    // a single middle element (if required) containing only height padding
    const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);

    const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
    const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);

    int ow = 0;
    int cur_l_pad_output = l_pad_output;
    // left-padding region, unrolled in chunks of at most max_ur_w
    while (cur_l_pad_output > 0) {
        const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
        cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
    }

    if (no_pad > 0) {
        // middle element: compute only if height padding is present,
        // but always advance the pointer when a middle slot exists
        const int ur_w = 1;
        if (h_padding) icb_loop(ur_w, 0, 0, true);
        if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
    }
    assert(ow + no_pad == ow_start);

    ow = ow_start;
    int cur_r_pad_output = r_pad_start;
    // right-padding region
    while (cur_r_pad_output > 0 && ow < jcp.ow) {
        const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, 0, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
    }
}
391 | |
// Top-level kernel generator: loads call arguments, selects at runtime
// between the width-padding-only path and the height/depth-padding path,
// then emits the data section (permute table and ic-tail mask) used by the
// reduced-lowering variant.
void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
    Label h_pad_label, end_label;

    assert(jcp.req_zero_point_buffer);
    assert(jcp.typesize_in == sizeof(char));

    preamble();

    mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
    mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
    mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);

    if (jcp.oc_without_padding != jcp.oc) {
        // prepare the oc-tail write mask used by store_output()
        const Reg32 reg_tmp = reg_scratch.cvt32();
        const int tail_size = jcp.oc_without_padding % jcp.oc_block;
        const int mask = (1 << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovw(ktail_mask, reg_tmp);
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
    }

    // any top/bottom overflow means this call handles height padding
    mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    // 3D: a reduced kd_padding implies front/back padding is present
    if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
        mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_overflow, jcp.kd);
        jne(h_pad_label, T_NEAR);
    }

    // Handle width padding region
    unroll_width(false);
    jmp(end_label, T_NEAR);

    // handle height padding region
    L(h_pad_label);
    unroll_width(true);

    L(end_label);

    postamble();

    // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
    // permute elements through permb into the VNNI layout '...16i4i'.
    if (jcp.is_relo) {
        align(64);
        L(permb_idx_label);
        // permb: id-bit for table selection is bit[6]
        const uint8_t select_src2_bit = 0x40;
        // permb: bits [5:0] select the element within each input table
        const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
                18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
                38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
                58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
                62, 15, 31, 47, 63};
        for (size_t i = 0; i < 64; ++i)
            db(select_src2_bit | permb_idx_table[i]);

        // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
        const int ic_tail_size
                = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
        if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
            align(64);
            L(ic_mask_label);

            assert(4 > ic_tail_size);
            // mask is on a 4-bit basis from the 4 ic elements in a zmm
            const int nibble = (1 << ic_tail_size) - 1;
            for (int i = 0; i < 16; ++i) {
                db(nibble | (nibble << 4));
            }
        }
    }
}
469 | |
// Generate the weights-copy kernel: reorder weights from an '~Owhi16o'
// layout into the VNNI-friendly '~OR16oVr' layout (r := w*h*i elements,
// V := vnni_width) using vpermw (bf16) or vpermb (int8), zero-padding the
// remainder up to the tile size.
void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;

    // required for use of VPERMB instruction
    assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
    assert(jcp.ic_block_int * jcp.typesize_in == 64);

    preamble();

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);

    // load permute indices from data section
    Label permute_index_table;
    mov(reg_tmp, permute_index_table);
    if (is_bf16)
        vmovdqu16(zmm_idx, ptr[reg_tmp]);
    else
        vmovdqu8(zmm_idx, ptr[reg_tmp]);

    const int vnni_width = is_bf16 ? 2 : 4;
    // r = number of reduction elements per oc block
    const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
    const int nb_r = div_up(r, vnni_width);
    // tail elements (in units of the element type) of the last 64-byte load
    const int rtail = (r % vnni_width) * jcp.oc_block;
    if (rtail > 0) {
        uint64_t mask = (UINT64_C(1) << rtail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask_load, reg_tmp);
    }
    // zero-pad the destination up to a full ic_block multiple
    const int nb_z = rnd_up(nb_r, jcp.ic_block);
    if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);

    const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
    const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
    const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);

    // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
    for (int g = 0; g < jcp.ngroups; g++) {
        for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
            int offset = 0;
            int rb = 0;
            for (; rb < nb_r; offset += 64, rb++) {
                // zero-mask the load of the final partial chunk
                auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
                        ? zmm_src | kmask_load | T_z
                        : zmm_src;
                if (is_bf16) {
                    vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermw(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu16(ptr[reg_dst + offset], zmm_dst);
                } else {
                    vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermb(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu8(ptr[reg_dst + offset], zmm_dst);
                }
            }
            // zero-fill the padded remainder of the destination
            for (; rb < nb_z; offset += 64, rb++) {
                if (is_bf16)
                    vmovdqu16(ptr[reg_dst + offset], zmm_zero);
                else
                    vmovdqu8(ptr[reg_dst + offset], zmm_zero);
            }
            add(reg_src, ocb_src_step);
            add(reg_dst, ocb_dst_step);
        }
    }

    postamble();

    // data section: interleaving permute indices mapping 16o x V(r)
    // source elements into VNNI order
    align(64);
    L(permute_index_table);
    const uint8_t no = 16; // 16o
    const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
    for (uint8_t o = 0; o < no; ++o) {
        for (uint8_t r = 0; r < nr; r++) {
            const uint8_t index = o + r * no;
            if (is_bf16)
                dw(index);
            else
                db(index);
        }
    }
}
553 | |
// Copy one input row (width dimension) into the padded buffer, writing
// zeros for positions that fall into left/right padding. 'lpad' is the
// left padding of this block, 'iw_len' the number of valid input columns,
// 'icb' the current input-channel block index.
void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
        int lpad, int iw_len, int icb) {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    int iwp_idx = 0;
    // there are min(gen_kw, jcp.stride_w) continuous sets of input
    // data (for each stride idx), they are placed one by one
    // without additional padding
    const bool are_sets_interleaved
            = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
    for (int set_idx = 0; set_idx < num_sets; set_idx++) {
        // width (in pbuffer columns) of the current stride set
        int set_width_padded = !jcp.is_pbuffer_strided
                ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
                : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
                                + (set_idx < gen_kw % num_sets ? 1 : 0)
                                       : jcp.ow_block;
        for (int set_shift = 0; set_shift < set_width_padded;
                set_shift++, iwp_idx++) {
            // map pbuffer column back to the source input column
            int iw_idx = set_idx * (jcp.dilate_w + 1)
                    + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
                    - lpad;
            size_t out_base_offset
                    = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
            if (iw_idx < 0 || iw_idx >= iw_len) {
                // left or right padding
                vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
            } else if (jcp.is_nspc) {
                // channels-last: one (possibly masked) 64-byte move per column
                size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
                        * jcp.ngroups * jcp.ic_without_padding;
                int ic = icb * jcp.ic_block_int_np;
                // TODO: use Xmm or Ymm moves for better small ic efficiency
                auto zmm_tmp_mask
                        = ic + jcp.ic_block_int <= jcp.ic_without_padding
                        ? zmm_tmp
                        : zmm_tmp | ktail_mask | T_z;
                if (is_bf16) {
                    vmovdqu16(
                            zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                } else {
                    vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                }
            } else {
                // blocked layout (bf16 only): copy ic_block chunks, zeroing
                // blocks past the real channel count
                assert(is_bf16);
                size_t inp_w_offset
                        = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
                for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
                    int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
                    size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
                                    * jcp.iw * jcp.ic_block
                            + inp_w_offset;
                    if (ic + jcp.ic_block <= jcp.ic) {
                        vmovdqu16(
                                ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
                    } else {
                        vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
                    }
                    size_t out_offset = out_base_offset
                            + (size_t)jcp.typesize_in * j * jcp.ic_block;
                    vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
                }
            }
        }
    }
}
622 | |
// Emit the row-copy dispatch: when the width is split into several ow
// blocks, generate runtime checks on reg_owb to select between up to four
// specializations (first block with left pad, last block with tail/right
// pad, penultimate block with right pad, and the general case).
void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
    if (jcp.nb_ow == 1) {
        copy_row_body(jcp.l_pad, jcp.iw, icb);
    } else {
        // input width needed by a block of 'cur_ow_block' outputs
        auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
            return (cur_ow_block - 1) * jcp.stride_w
                    + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
        };

        // same, but clamped to what remains of the input for block 'owb'
        auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
            auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
            if (owb < 0) return len_req;
            int ow_block_start = nstl::max(
                    0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
            return nstl::min(jcp.iw - ow_block_start, len_req);
        };

        int general_owb_cases = jcp.nb_ow;
        Xbyak::Label copy_row_done_label;
        bool special_first_block_case = jcp.l_pad > 0;
        if (special_first_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_first_block_case_label;
            cmp(reg_owb, 0);
            jne(skip_first_block_case_label, T_NEAR);
            copy_row_body(jcp.l_pad,
                    get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_first_block_case_label);
        }
        bool special_last_block_case = false
                // has ow_block_tail
                || jcp.ow % jcp.ow_block != 0
                // there is no ow_block_tail but right padding exists
                || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_last_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_last_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 1);
            jne(skip_last_block_case_label, T_NEAR);
            int ow_block_tail = jcp.ow % jcp.ow_block;
            int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_last_block_case_label);
        }

        bool special_penult_block_case = true
                // if nb_ow = 2 and l_pad > 0 it's the same as
                // special_first_block_case
                && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
                // right padding exists in penult block
                && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_penult_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_penult_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 2);
            jne(skip_penult_block_case_label, T_NEAR);
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_penult_block_case_label);
        }

        if (general_owb_cases > 0) // general case
            copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);

        L(copy_row_done_label);
    }
}
696 | |
// Reduced-lowering ('is_relo') copy kernel body: transpose a (kh x kw)
// input patch into the padded buffer, zero-filling left/right column
// overflow and top/bottom row overflow. All loop bounds are runtime
// values read from the call params, so every loop is emitted as JIT code.
void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
    assert(jcp.nb_ic_int == 1);
    assert(jcp.ic_block_int * jcp.typesize_in == 64);
    assert(jcp.is_nspc);

    // load a (1 << tail)-1 bit mask into the given opmask register
    auto load_mask = [=](int tail, Opmask kmask) {
        uint64_t mask = (UINT64_C(1) << tail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask, reg_tmp);
    };

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    const int inp_w_step
            = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
    const int inp_h_step = jcp.iw * inp_w_step;
    const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
    const int out_w_step = jcp.kh * out_h_step;
    const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
    if (tail_size > 0) load_mask(tail_size, ktail_mask);

    // zero-fill all channels at the column pointed to by tmp_out_ptr,
    // masking the final partial ic chunk
    auto zero_it = [=](reg64_t tmp_out_ptr) {
        for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
            const int offset = ic * jcp.typesize_in;
            const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
            Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
            if (is_bf16)
                vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
            else
                vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
        }
    };

    // pointer to 1st needed element in src buffer
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    // pointer to 1st needed element in dst buffer
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);

    // total number of rows to copy
    mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);

    // number of rows of src buffer to copy
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    // number of zero-padded rows above src buffer to copy
    mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
    // number of zero-padded rows below src buffer to copy
    mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);

    // number of columns of src buffer to copy
    mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
    // number of zero-padded columns after src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    { // Handle Left Overflow
        Label label_lov, label_lov_skip;
        test(reg_lov, reg_lov);
        jz(label_lov_skip, T_NEAR);
        L(label_lov); // handle left or right overflow
        {
            Label label_lov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_lov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_lov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_lov);
            jnz(label_lov, T_NEAR);
        }
        L(label_lov_skip);
    }

    // save output pointer for later use
    mov(reg_save_out_ptr, reg_out_ptr);

    // just in case there is no meat...
    Label label_kwp_end;
    test(reg_kwp, reg_kwp);
    jz(label_kwp_end, T_NEAR);

    // Unroll over W-dimension in powers of 2
    Label label_tov;
    Label label_khp, label_no_khp;
    Label label_bov;
    test(reg_tov, reg_tov);
    jnz(label_tov, T_NEAR);
    test(reg_khp, reg_khp);
    jnz(label_khp, T_NEAR);
    test(reg_bov, reg_bov);
    jnz(label_bov, T_NEAR);
    jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters

    L(label_tov); // handle top overflow
    {
        Label label_tov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_tov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_tov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_tov);
        jnz(label_tov, T_NEAR);
    }
    test(reg_khp, reg_khp);
    jz(label_no_khp, T_NEAR);
    L(label_khp); // handle kh padding (not fully unrolled)
    {
        Label label_khp_inner;
        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_khp_inner);
        {
            // copy one input column, ic_block_int channels at a time
            for (int ic = 0; ic < jcp.ic_without_padding;
                    ic += jcp.ic_block_int) {
                const int offset = ic * jcp.typesize_in;
                const bool masked
                        = ic + jcp.ic_block_int > jcp.ic_without_padding;
                // zero masking is needed to avoid dependency on destination
                Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
                Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
                if (is_bf16) {
                    vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
                } else {
                    vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
                }
            }
            add(reg_aux_inp_ptr, inp_w_step);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_khp_inner, T_NEAR);
        }
        add(reg_inp_ptr, inp_h_step);
        add(reg_out_ptr, out_h_step);
        dec(reg_khp);
        jnz(label_khp, T_NEAR);
    }
    L(label_no_khp);
    test(reg_bov, reg_bov);
    jz(label_kwp_end, T_NEAR);
    L(label_bov); // handle bottom overflow
    {
        Label label_bov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_bov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_bov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_bov);
        jnz(label_bov, T_NEAR);
    }
    L(label_kwp_end);

    { // Handle Right Overflow
        Label label_rov, label_rov_skip;
        // retrieve output pointer
        mov(reg_out_ptr, reg_save_out_ptr);
        // calculate the shift
        imul(reg_tmp, reg_kwp, out_w_step);
        // shift past the body
        add(reg_out_ptr, reg_tmp);
        // skip if no right overflow
        test(reg_rov, reg_rov);
        jz(label_rov_skip, T_NEAR);

        L(label_rov); // handle left or right overflow
        {
            Label label_rov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_rov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_rov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_rov);
            jnz(label_rov, T_NEAR);
        }
        L(label_rov_skip);
    }

    // For bf16, zero-pad an extra cacheline to avoid NaNs
    // For int8, it is sufficient to zero-pad the weights only
    if (is_bf16) {
        // shift forward to align h index to end of needed buffer
        imul(reg_tmp, reg_kht, out_h_step);
        add(reg_out_ptr, reg_tmp);
        // shift backward to align w index to end of needed buffer
        sub(reg_out_ptr, out_w_step);
        vmovdqu16(ptr[reg_out_ptr], zmm_zero);
    }
}
911 | |
void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {

    // Special copy kernel for reduced lowering
    if (jcp.is_relo) {
        assert(jcp.nb_ic_int == 1);
        preamble();
        copy_row_reduced_lowering();
        postamble();
        return;
    }

    preamble();

    const bool is_3d = jcp.ndims == 5;
    // Load kernel arguments from the jit_conv_call_s parameter block.
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
    if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
    mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    // Prepare an opmask covering the ic tail (nspc layout only); used by
    // copy_row() for masked loads/stores of the last partial ic block.
    if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
        int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    // Loop over internal ic blocks; each iteration copies one (possibly 3d)
    // input volume slice into the padded buffer.
    for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
        Xbyak::Label kd_label, no_kd_label;
        Xbyak::Label kh_label, no_kh_label, icb_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        if (is_3d) {
            // 3d case: iterate over kd depth slices; save/restore the aux
            // pointers around the inner 2d copy via the stack.
            cmp(reg_kdp, 0);
            jle(no_kd_label, T_NEAR);
            mov(reg_kdc, reg_kdp);
            L(kd_label);
            push(reg_aux_inp_ptr);
            push(reg_aux_out_ptr);
        }
        cmp(reg_khp, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do
        mov(reg_khc, reg_khp);

        cmp(reg_tover, 0);
        jle(no_kh_tover_label, T_NEAR);

        // Zero-fill the rows of top padding (t_overflow rows).
        mov(reg_kh_over, reg_tover);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_khc, reg_tover);
        L(no_kh_tover_label);

        // Copy the body rows: khc counts down to b_overflow, so the last
        // `bover` rows are left for zero-filling below.
        cmp(reg_khc, reg_bover);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_row(icb);
            size_t inp_h_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
                    : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding;
            size_t out_h_offset
                    = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;

            add(reg_aux_inp_ptr, inp_h_offset);
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            cmp(reg_khc, reg_bover);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_khc, 0);
        jle(no_kh_bover_label, T_NEAR);

        // Zero-fill the rows of bottom padding (remaining khc rows).
        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            jnz(kh_bover_label, T_NEAR);
        }
        // Depth step in the output buffer: one padded h*w plane plus an
        // extra ic_block_int gap; also reused for the icb advance below.
        size_t out_d_offset = (size_t)jcp.typesize_in
                * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
        L(no_kh_bover_label);
        if (is_3d) {
            size_t inp_d_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
                            * (jcp.dilate_d + 1)
                    : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding * (jcp.dilate_d + 1);
            pop(reg_aux_out_ptr);
            pop(reg_aux_inp_ptr);
            add(reg_aux_inp_ptr, inp_d_offset);
            add(reg_aux_out_ptr, out_d_offset);
            dec(reg_kdc);
            jnz(kd_label, T_NEAR);
            L(no_kd_label);
        }
        // End IC Loop
        size_t inp_cb_offset = !jcp.is_nspc
                ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
                        * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
                : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
        size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;

        add(reg_inp_ptr, inp_cb_offset);
        add(reg_out_ptr, out_cb_offset);
    }

    postamble();
}
1051 | |
1052 | jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t( |
1053 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
1054 | const memory_desc_t &dst_md) |
1055 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
1056 | , jcp(ajcp) |
1057 | , attr_(attr) { |
1058 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
1059 | using namespace binary_injector; |
1060 | const auto &rhs_addr_reg = bin_injector_helper_reg_1; |
1061 | const auto &rhs_helper_reg = bin_injector_helper_reg_2; |
1062 | const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3; |
1063 | static constexpr bool preserve_gpr = false; |
1064 | static constexpr bool preserve_vmm = false; |
1065 | const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; |
1066 | static constexpr bool use_exact_tail_scalar_bcast = true; |
1067 | |
1068 | const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { |
1069 | 31, rhs_addr_reg, rhs_helper_reg, rhs_addr_cache_reg, |
1070 | preserve_gpr, preserve_vmm, |
1071 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
1072 | memory_desc_wrapper(dst_md), tail_size, ktail_mask, |
1073 | use_exact_tail_scalar_bcast}; |
1074 | const binary_injector::static_params_t static_params { |
1075 | this->param1, rhs_arg_static_params}; |
1076 | |
1077 | postops_injector_ = utils::make_unique< |
1078 | injector::jit_uni_postops_injector_t<avx512_core>>( |
1079 | this, jcp.post_ops, static_params); |
1080 | } |
1081 | copy_to_pbuffer_ |
1082 | = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp); |
1083 | if (jcp.is_relo) |
1084 | copy_to_wbuffer_ |
1085 | = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>( |
1086 | jcp); |
1087 | } |
1088 | |
1089 | status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() { |
1090 | CHECK(jit_generator::create_kernel()); |
1091 | CHECK(copy_to_pbuffer_->create_kernel()); |
1092 | if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel()); |
1093 | if (jcp.req_zero_point_buffer) { |
1094 | zp_pbuff_kernel_ |
1095 | = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>( |
1096 | jcp); |
1097 | if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory; |
1098 | CHECK(zp_pbuff_kernel_->create_kernel()); |
1099 | } |
1100 | return status::success; |
1101 | } |
1102 | |
1103 | // Tile register decomposition |
1104 | // { C_BASE = 0, I_BASE = 4, W_BASE = 6, } |
1105 | int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor( |
1106 | int h, int i, bool is_h_tail) const { |
1107 | const int C_BASE = 0; |
1108 | const int C_LAST = 4; |
1109 | assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles); |
1110 | MAYBE_UNUSED(C_LAST); |
1111 | const int tile = C_BASE |
1112 | + (jcp.nb_oh_blocking > 1 |
1113 | ? h * jcp.nb_oh_blocking + i |
1114 | : (int)is_h_tail * jcp.nb_oc_blocking + i); |
1115 | assert(C_BASE <= tile && tile < C_LAST); |
1116 | return tile; |
1117 | } |
1118 | int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor( |
1119 | int h, bool is_h_tail) const { |
1120 | const int I_BASE = 4; |
1121 | const int I_LAST = 6; |
1122 | assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles); |
1123 | MAYBE_UNUSED(I_LAST); |
1124 | const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail); |
1125 | assert(I_BASE <= tile && tile < I_LAST); |
1126 | return tile; |
1127 | } |
1128 | int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const { |
1129 | const int W_BASE = 6; |
1130 | const int W_LAST = 8; |
1131 | assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles); |
1132 | MAYBE_UNUSED(W_LAST); |
1133 | const int tile = W_BASE + i; |
1134 | assert(W_BASE <= tile && tile < W_LAST); |
1135 | return tile; |
1136 | } |
1137 | |
1138 | // Shifts and offsets |
1139 | size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const { |
1140 | return (size_t)jcp.kd * get_inp_d_step(); |
1141 | } |
1142 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const { |
1143 | return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw |
1144 | * jcp.ic_block_int_np * jcp.oc_block; |
1145 | } |
1146 | size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const { |
1147 | return (size_t)jcp.typesize_in |
1148 | * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int); |
1149 | } |
1150 | size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const { |
1151 | return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np |
1152 | * (jcp.dilate_h + 1); |
1153 | } |
1154 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const { |
1155 | return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np |
1156 | * jcp.oc_block; |
1157 | } |
1158 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const { |
1159 | return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np |
1160 | * jcp.oc_block; |
1161 | } |
1162 | size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset( |
1163 | int ohb, int ocb, size_t typesize) const { |
1164 | size_t el_offset = jcp.is_nspc |
1165 | ? (size_t)ocb * jcp.oc_block |
1166 | + (size_t)ohb * jcp.ow * jcp.ngroups |
1167 | * jcp.oc_without_padding |
1168 | : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block |
1169 | + (size_t)ohb * jcp.ow * jcp.oc_block; |
1170 | return (size_t)typesize * el_offset; |
1171 | } |
1172 | size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset( |
1173 | int ohb, int ocb, int j, size_t typesize) const { |
1174 | size_t offset_w = jcp.is_nspc |
1175 | ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding |
1176 | : (size_t)typesize * j * jcp.oc_block; |
1177 | return get_out_ocb_offset(ohb, ocb, typesize) + offset_w; |
1178 | } |
1179 | size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift( |
1180 | int width, size_t typesize) const { |
1181 | return jcp.is_nspc |
1182 | ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding |
1183 | : (size_t)typesize * width * jcp.oc_block; |
1184 | } |
1185 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset( |
1186 | int ohb, int ocb) const { |
1187 | size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block |
1188 | + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width |
1189 | * jcp.oc_block; |
1190 | return jcp.typesize_acc * el_offset; |
1191 | } |
1192 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset( |
1193 | int ohb, int ocb, int j) const { |
1194 | return get_wsp_ocb_offset(ohb, ocb) |
1195 | + (size_t)jcp.typesize_acc * j * jcp.oc_block; |
1196 | } |
1197 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const { |
1198 | return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width |
1199 | * jcp.oc_block * jcp.nb_oc_blocking; |
1200 | } |
1201 | size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const { |
1202 | size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block; |
1203 | size_t raw_oc_subblock_step |
1204 | = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block; |
1205 | size_t oc_subblock_step = jcp.is_relo |
1206 | ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block) |
1207 | : raw_oc_subblock_step; |
1208 | el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step; |
1209 | return jcp.typesize_in * el_offset; |
1210 | } |
1211 | size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const { |
1212 | size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh |
1213 | : jcp.is_pbuffer_strided ? 1 : jcp.stride_w) |
1214 | * jcp.ic_block_int_np; |
1215 | return (size_t)jcp.typesize_in * jcp.tile_width * w_step; |
1216 | } |
// Byte offset into the (padded, possibly pre-strided) input buffer for the
// row block `ohb` and kernel column `kw`.
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
    // Reduced lowering keeps kh rows contiguous per output row.
    if (jcp.is_relo)
        return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
    // calculate offset by height dimension
    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
    size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
            * jcp.ic_block_int_np;

    // add offset by width dimension
    if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
        el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
    } else if (jcp.dilate_w > 0) {
        el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
    } else {
        // dilate_w == 0 && stride_w > 1
        // there are min(jcp.kw, jcp.stride_w) continuous sets of input data
        // (foreach stride idx), they are placed one by one without additional
        // padding

        // calculate set idx for current kw value
        int set_idx = kw % jcp.stride_w;
        // calculate shift within set for current kw value
        int set_shift = kw / jcp.stride_w;

        // calculate the beginning of the current set along width, each set
        // with index set_i contains number of elements along width equal to
        // jcp.ow - 1 + jcp.kw / jcp.stride_w
        //     + (set_i < jcp.kw % jcp.stride_w)
        size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
                + nstl::min(set_idx, jcp.kw % jcp.stride_w);
        el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
    }
    return jcp.typesize_in * el_offset;
}
1252 | |
1253 | size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset( |
1254 | int ocb, int zp_h, int zp_w) const { |
1255 | const size_t ocb_offset = (size_t)ocb * jcp.oc_block; |
1256 | const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups |
1257 | * jcp.oc_without_padding; |
1258 | return (ocb_offset + sp_offset) * sizeof(int32_t); |
1259 | } |
1260 | |
1261 | int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset( |
1262 | int index, int mid, int s_pad_output, int e_pad_output) { |
1263 | using namespace nstl; |
1264 | const int mid_end = e_pad_output - 1; |
1265 | int zp_mid = min(mid, max(0, index - mid_end)); |
1266 | int zp_pad_offset |
1267 | = accum_with_upper_bound(index, s_pad_output, e_pad_output); |
1268 | return zp_pad_offset + zp_mid; |
1269 | } |
1270 | |
1271 | // Code generation |
1272 | void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) { |
1273 | for (int h = 0; h < jcp.nb_oh_blocking; h++) |
1274 | for (int i = 0; i < jcp.nb_oc_blocking; i++) |
1275 | tilezero(Tmm(get_out_tensor(h, i, tail))); |
1276 | } |
1277 | |
1278 | void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters( |
1279 | bool start_with_last_tile_block) { |
1280 | prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0 |
1281 | ? jcp.tile_tail |
1282 | : jcp.tile_width; |
1283 | |
1284 | row_count_ = 0; |
1285 | is_store_done_ = false; |
1286 | is_buffer_empty_ = true; |
1287 | } |
1288 | |
1289 | size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block( |
1290 | const int block_size, const int pad_output) { |
1291 | return (size_t)(pad_output >= block_size ? block_size : 0) |
1292 | + (pad_output % block_size); |
1293 | } |
1294 | |
// Compute the reduced number of rows of `dim_size` that actually need
// distinct zero-point compensation entries: at most one full block (plus
// remainder) each for the start padding, the unpadded middle, and the end
// padding, accounting for block-boundary overlap between middle and end.
size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
        const int dim_size, const int block_size, const int s_pad_output,
        const int e_pad_output) {
    using namespace nstl;

    // start padding (s_pad)
    int s_pad_limit = reduce_to_block(block_size, s_pad_output);
    int s_pad_area_blk = rnd_up(s_pad_limit, block_size);

    // middle (no padding)
    int no_pad_area = max(
            0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
    int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);

    // end padding (e_pad)
    int no_pad_area_shift = no_pad_area % block_size;
    // part of the end padding that falls into the last (partial) middle block
    int e_pad_area_overlap
            = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
    // middle and end padding shift
    int e_pad_shift_limit
            = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
    int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
    // full end padding block
    int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);

    // calculate reduced size of s_pad, middle and e_pad blocks.
    return min((size_t)dim_size,
            (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
                    + e_pad_limit);
}
1325 | |
1326 | Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask( |
1327 | const Ymm &ymm_in, bool mask_flag, bool store) { |
1328 | return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z) |
1329 | : ymm_in; |
1330 | } |
1331 | |
1332 | Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask( |
1333 | const Zmm &zmm_in, bool mask_flag, bool store) { |
1334 | return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
1335 | : zmm_in; |
1336 | } |
1337 | |
// Load a vector of `type_in` elements from `op` into `zmm_in` as f32:
// s8/u8 are widened to s32 then converted, bf16 is widened by a 16-bit
// left shift, f32/s32 are loaded directly (s32 is also converted).
void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm &zmm_in, const Operand &op, bool mask_flag) {
    // With mask_flag the load is zero-masked to the oc tail.
    const Zmm zmm = zmm_mask(zmm_in, mask_flag);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: vmovups(zmm, op); break;
        case data_type::s8: vpmovsxbd(zmm, op); break;
        case data_type::u8: vpmovzxbd(zmm, op); break;
        case data_type::bf16:
            vpmovzxwd(zmm, op);
            vpslld(zmm, zmm, 0x10);
            break;
        default: assert(!"unsupported data type" );
    }
    // Integer inputs still hold s32 bit patterns; convert the register to
    // f32 (masked-off lanes were zeroed above, so converting the full
    // register is safe).
    if (!utils::one_of(type_in, data_type::f32, data_type::bf16))
        vcvtdq2ps(zmm_in, zmm_in);
}
1355 | |
// Register a `sum` post-op lambda with the post-ops injector: it reloads
// the previous destination value, optionally subtracts the sum zero-point,
// and accumulates it (scaled) into zmm_out. The lambda runs later, inside
// compute_vector(), in post-ops chain order.
void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
        const float *p_sum_scale, const int32_t *p_sum_zp,
        const Xbyak::Address &addr, const bool mask_flag) {
    if (p_sum_scale) {
        // Capture scale/zp *values* explicitly: the lambda outlives this
        // function's scope, so by-reference capture of locals is unsafe.
        const float p_sum_scale_val = *p_sum_scale;
        const int32_t p_sum_zp_val = *p_sum_zp;
        const auto sum_injector = [&, p_sum_scale_val, p_sum_zp_val,
                                          mask_flag]() {
            cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
            if (p_sum_zp_val != 0) {
                // broadcast the zero-point, convert to f32, and subtract
                vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
                vsubps(zmm_prev_dst, zmm_sum_zp);
            }
            if (p_sum_scale_val == 1.f)
                vaddps(zmm_out, zmm_prev_dst);
            else
                // fused multiply-add with a broadcast scale from memory
                vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
        };
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}
1378 | |
// Apply the configured post-ops chain (sum/eltwise/binary) to zmm_out.
// `addr` is the destination address matching this vector and `off` its
// byte offset, both needed by the binary injector to address the rhs
// tensor element.
void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
        const float *p_sum_scale, const int32_t *p_sum_zp,
        const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
    if (jcp.with_eltwise || jcp.with_binary
            || (jcp.with_sum && p_sum_scale != nullptr)) {
        // Register the sum lambda first so it executes in chain order.
        apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);

        const auto vmm_idx = zmm_out.getIdx();
        if (jcp.with_binary) {
            // Map this vmm to its output register/offset so the injector
            // can compute the rhs operand address.
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
            if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

            postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
        } else {
            postops_injector_->compute_vector(vmm_idx);
        }
    }
}
1399 | |
1400 | void jit_avx512_core_amx_fwd_kernel_t::store_output_ymm_bf16( |
1401 | const int idx, const Xbyak::Address &addr, const bool mask_flag) { |
1402 | Ymm ymm_out = Ymm(idx); |
1403 | vcvtneps2bf16(ymm_out, Zmm(idx)); |
1404 | vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true)); |
1405 | } |
1406 | |
// Store one f32 accumulator vector for a bf16 convolution: apply sum
// (direct, not via the injector), bias, and post-ops, then store as
// bf16 or f32 depending on the destination data type.
void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
        const Zmm &zmm_out, int ocb, int h, int w) {
    // Tail masking only applies to the last oc block in nspc layout.
    const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
            && ocb == (jcp.nb_oc_blocking - 1);

    const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
    auto addr = EVEX_compress_addr(reg_out_ptr, off);

    const auto &p = attr_.post_ops_;

    const int sum_idx = p.find(primitive_kind::sum);
    if (sum_idx != -1) {
        // Accumulate the previous destination value; bf16 is widened to
        // f32 via a 16-bit left shift before the add.
        if (jcp.dst_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vpslld(zmm_prev_dst, zmm_prev_dst, 16);
            vaddps(zmm_out, zmm_prev_dst);
        } else {
            vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vaddps(zmm_out, zmm_prev_dst);
        }
    }
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        // bf16 bias is widened to f32 the same way as the sum operand.
        if (jcp.bia_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
            vpslld(zmm_bias, zmm_bias, 16);
            vaddps(zmm_out, zmm_bias);
        } else
            vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
    }

    // Sum was already applied above, so the injector's sum slot is skipped.
    static constexpr auto skip_sum_injection = nullptr;
    apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off,
            mask_flag);

    if (jcp.dst_dt == data_type::bf16) {
        store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
    } else {
        vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
    }
}
1449 | |
// Store one s32 accumulator vector for an int8 convolution: apply
// zero-point compensations (still in s32), convert to f32, apply scales,
// bias and post-ops, then saturate/convert to the destination data type.
void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
        const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
        const int zp_h, const int zp_w) {
    const int nb_oc_block = jcp.nb_oc_blocking;
    const int oc_block = jcp.oc_block;
    // Tail masking only applies to the last oc block.
    const bool mask_flag = true && jcp.oc_without_padding != jcp.oc
            && ocb == (nb_oc_block - 1);

    const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
    auto addr = EVEX_compress_addr(reg_out_ptr, off);

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    // Load pointers to the sum scale/zero-point only when they are
    // actually read by the sum injector lambda.
    if (p_sum_scale) {
        if (*p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        if (*p_sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
    }

    // scale_offset is 0 for a common (per-tensor) scale.
    int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * ocb * oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
    }
    if (compute_zp) {
        assert(jcp.req_zero_point_buffer);
        // add zero-point padding compensation when accum data is S32
        const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
        vmovups(m_zmm_zp,
                EVEX_compress_addr(reg_zero_point_pbuff,
                        get_zp_comp_offset(ocb, zp_h, zp_w)));
        const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
        vpaddd(m_zmm_out, zmm_out, zmm_zp);
    }
    if (jcp.src_zero_point) {
        // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
        int zp_offset = sizeof(int32_t) * ocb * oc_block;
        const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
        vpmulld(m_zmm_zp, zmm_src_zp,
                EVEX_compress_addr(reg_zp_compensation, zp_offset));
        vpaddd(zmm_out, zmm_out, zmm_zp);
    }

    /* add bias and zero-point to zmm_accum */
    vcvtdq2ps(zmm_out, zmm_out);
    const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
    vmulps(zmm_out_msk, zmm_out,
            EVEX_compress_addr(reg_ptr_scales, scale_offset));
    if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);

    apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);

    // Destination-side scale and zero-point are applied last, in f32.
    if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }
    if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dst_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
        saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
        vcvtps2dq(zmm_out, zmm_out);
    }

    const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);

    // Final store with down-conversion matching the destination type.
    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32: vmovups(addr, zmm_out_store); break;
        case data_type::bf16:
            store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
            break;
        case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
        case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
        default: assert(!"unknown dst_dt" );
    }
}
1536 | |
1537 | void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out, |
1538 | int ocb, int h, int w, const bool compute_zp, const int zp_h, |
1539 | const int zp_w) { |
1540 | /* |
1541 | Output: |
1542 | jcp.is_nspc !jcp.is_nspc |
1543 | --------------------- --------------------- |
1544 | INT8: [N][H][W][NBOC][16OC] |
1545 | BF16: [N][H][W][NBOC][16OC] or [N][NBOC][H][W][16OC] |
1546 | */ |
1547 | if (jcp.src_dt == data_type::bf16) { |
1548 | store_output_vector_bf16(zmm_out, ocb, h, w); |
1549 | } else { |
1550 | store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w); |
1551 | } |
1552 | } |
1553 | |
// Spill the accumulator tiles into the workspace and (when `do_store` is
// set) convert/store every row to the destination tensor, handling the
// h-tail, zero-point padding bookkeeping, and output pointer advance.
void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
        bool do_store, const bool handle_h_blk, const int t_pad_output,
        const int b_pad_output, const int l_pad_output, const int r_pad_output,
        const bool is_last_oh_block, const bool zp_3d_pad) {
    // NOTE: captures by copy; emits code for one oh/oc blocking step.
    auto store_output_block = [=](int width, int tail, bool do_store,
                                      bool is_last_h = false) {
        // Calculate the number of oh blocks; it may differ on last call
        const int last_h_blks
                = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
        const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
                                                         : jcp.nb_oh_blocking;
        // Calculate the number of oh rows per tile; it may differ on last call
        const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
                ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
                : h_blks * jcp.oh_per_tile;
        // generalized kernel width and padded row width (for oh_per_tile > 1)
        const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
        const int owp = gen_kw + jcp.ow - 1;

        // Load destination-side quantization parameters once per block.
        if (jcp.dst_scale) {
            mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
            vmovups(zmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0));
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
            mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
            vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
        }
        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
            vcvtdq2ps(zmm_dst_zp,
                    EVEX_compress_addr(reg_dst_zero_point, 0, true));
        }

        for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
        for (int ohb = 0; ohb < h_blks; ohb++) {
            /* Formats: Workspace: [NBOC][W][16OC] */
            tilestored(ptr[reg_wsp_ptr + reg_wei_stride
                               + get_wsp_ocb_offset(ohb, ocb)],
                    Tmm(get_out_tensor(ohb, ocb, tail)));
            is_buffer_empty_ = false;
            is_store_done_ = false;

            // preserve registers used by binary post_ops injector
            const injector_utils::conditional_register_preserve_guard_t
                    cond_register_guard(jcp.with_binary, this,
                            {bin_injector_helper_reg_1,
                                    bin_injector_helper_reg_2,
                                    bin_injector_helper_reg_3});

            // Convert and store each workspace row (only when do_store).
            for (int tw = 0; tw < width && do_store; tw++) {
                // height
                const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
                const bool zp_h_pad
                        = oh_index < t_pad_output || oh_index >= b_pad_output;
                const int zp_h = get_zp_index_offset(
                        oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
                // width
                const int ow_index = tw % owp;
                const bool zp_w_pad
                        = ow_index < l_pad_output || ow_index >= r_pad_output;
                const int zp_w = get_zp_index_offset(
                        ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);

                const bool compute_zp = jcp.req_zero_point_buffer
                        && (zp_3d_pad || zp_w_pad || zp_h_pad);

                assert(IMPLICATION(jcp.oh_per_tile == 1,
                        ohb == oh_index && tw == ow_index));
                // Skip rows beyond the h tail or the valid output width.
                if (oh_index < h_tail && ow_index < jcp.ow) {
                    Zmm zmm_r = zmm_out(tw);
                    vmovups(zmm_r,
                            ptr[reg_wsp_ptr
                                    + get_wsp_row_offset(ohb, ocb, tw)]);
                    store_output_vector(zmm_r, ocb, oh_index, ow_index,
                            compute_zp, zp_h, zp_w);
                }
            }
        }
    };

    // adjustment in case interleave store is turned off
    do_store = do_store || jcp.per_one_pstore == 0;
    // Remember padding for the deferred (interleaved) store pass.
    if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
    if (!handle_h_blk) {
        store_output_block(width, tail, do_store, is_last_oh_block);
    } else {
        if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
            store_output_block(width, tail, do_store);
        } else {
            // Runtime branch on the last_h flag: the tail block needs a
            // different row count than the full blocks.
            Label label_oh_oc_store, label_done;
            mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
            cmp(reg_last_h, 0);
            jne(label_oh_oc_store, T_NEAR);
            store_output_block(width, tail, do_store, true); // last h
            jmp(label_done, T_NEAR);
            L(label_oh_oc_store);
            store_output_block(width, tail, do_store, false);
            L(label_done);
        }
    }
    if (do_store) {
        // Advance the output (and zero-point buffer) pointers past the
        // columns just written.
        add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
        if (jcp.req_zero_point_buffer) {
            const size_t sp_shift
                    = accum_with_upper_bound(width, l_pad_output, r_pad_output);
            add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
        }
    }
}
1663 | |
// Flush up to 'jcp.per_one_pstore' previously accumulated rows from the
// workspace to the destination. Called between tile computations so that the
// stores are interleaved with (and hidden behind) the AMX tdp* instructions.
// The unrolling state is tracked at *generation* time via the members
// row_count_, prv_width_, is_store_done_ and is_buffer_empty_.
void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
        int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
    for (int c = 0;
            c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
            c++) {
        // row_count = ohb * OCB * TW + ocb * TW + tw
        int tw = row_count_ % prv_width_;
        int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
        int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;

        // preserve registers used by binary post_ops injector
        const injector_utils::conditional_register_preserve_guard_t
                cond_register_guard(jcp.with_binary, this,
                        {bin_injector_helper_reg_1, bin_injector_helper_reg_2});

        // height
        const int oh_index = ohb;
        const bool zp_h_pad
                = oh_index < t_pad_output || oh_index >= b_pad_output;
        const int zp_h = get_zp_index_offset(
                oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
        // width: take the padding limits queued for the current ow-block when
        // the store was deferred; fall back to "no padding" if none is queued
        const int l_pad_output
                = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
        const int r_pad_output
                = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;

        const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
        const int zp_w = get_zp_index_offset(
                tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);

        // zero-point padding values are only computed in padded areas
        const bool compute_zp = jcp.req_zero_point_buffer
                && (zp_3d_pad || zp_w_pad || zp_h_pad);

        Zmm zmm_r = zmm_out(tw);
        vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
        store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
        row_count_++;

        // once the full set of rows is flushed, advance the output (and
        // zero-point buffer) pointers and reset the generation-time counters
        if (row_count_
                == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
            add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
            if (jcp.req_zero_point_buffer) {
                const size_t sp_shift = accum_with_upper_bound(
                        prv_width_, l_pad_output, r_pad_output);
                add(reg_zero_point_pbuff,
                        get_out_shift(sp_shift, sizeof(int32_t)));
                if (!w_padding.empty()) w_padding.pop();
            }
            row_count_ = 0;
            is_store_done_ = true;
            prv_width_ = width;
        }
    }
}
1719 | |
1720 | void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width, |
1721 | bool do_store, const bool handle_h_blk, const int t_pad_output, |
1722 | const int b_pad_output, const int l_pad_output, const int r_pad_output, |
1723 | const bool zp_3d_pad, const bool is_last_oh_block) { |
1724 | const bool tail = width == jcp.tile_tail; |
1725 | |
1726 | auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) { |
1727 | if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) { |
1728 | tdpbf16ps(x1, x2, x3); |
1729 | } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) { |
1730 | tdpbuud(x1, x2, x3); |
1731 | } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) { |
1732 | tdpbusd(x1, x2, x3); |
1733 | } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) { |
1734 | tdpbsud(x1, x2, x3); |
1735 | } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) { |
1736 | tdpbssd(x1, x2, x3); |
1737 | } else { |
1738 | assert(!"unsupported combination" ); |
1739 | } |
1740 | }; |
1741 | |
1742 | prepare_output(tail); |
1743 | |
1744 | // prepare registers for when 'interleave_store()' is computed |
1745 | if (jcp.dst_scale) { |
1746 | mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]); |
1747 | vmovups(zmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0)); |
1748 | } |
1749 | if (jcp.src_zero_point) { |
1750 | mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]); |
1751 | mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]); |
1752 | vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0)); |
1753 | } |
1754 | if (jcp.dst_zero_point) { |
1755 | mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]); |
1756 | vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true)); |
1757 | } |
1758 | |
1759 | // reduced lowering path |
1760 | if (jcp.is_relo) { |
1761 | const int nreduce = jcp.nreduce; |
1762 | const int stride = jcp.ic_block_int; // ie 64 (32) for int8 (bf16) |
1763 | |
1764 | push(reg_inp_ptr); |
1765 | push(reg_wei_ptr); |
1766 | |
1767 | for (int ireduce = 0; ireduce < nreduce; ireduce += stride) { |
1768 | for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) { |
1769 | tileloadd(Tmm(get_inp_tensor(ohb, tail)), |
1770 | ptr[reg_inp_ptr + get_inp_offset(ohb, 0) |
1771 | + reg_inp_stride]); |
1772 | } |
1773 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1774 | tileloadd(Tmm(get_wei_tensor(ocb)), |
1775 | ptr[reg_wei_ptr + get_wei_offset(ocb, 0) |
1776 | + reg_wei_stride]); |
1777 | for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) { |
1778 | tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)), |
1779 | Tmm(get_inp_tensor(ohb, tail)), |
1780 | Tmm(get_wei_tensor(ocb))); |
1781 | interleave_store(width, t_pad_output, b_pad_output); |
1782 | } |
1783 | } |
1784 | if (ireduce + stride < nreduce) { |
1785 | add(reg_inp_ptr, stride * jcp.typesize_in); |
1786 | add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in); |
1787 | } |
1788 | } |
1789 | pop(reg_wei_ptr); |
1790 | pop(reg_inp_ptr); |
1791 | |
1792 | store_output(width, tail, do_store, handle_h_blk, t_pad_output, |
1793 | b_pad_output, l_pad_output, r_pad_output, is_last_oh_block); |
1794 | |
1795 | add(reg_inp_ptr, get_inp_shift()); |
1796 | |
1797 | return; |
1798 | } |
1799 | |
1800 | auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) { |
1801 | return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step() |
1802 | + kh * get_wei_h_step() + get_wei_offset(ocb, kw); |
1803 | }; |
1804 | |
1805 | auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) { |
1806 | return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step() |
1807 | + kh * get_inp_h_step() + get_inp_offset(ohb, kw); |
1808 | }; |
1809 | |
1810 | auto safe_tileloadd |
1811 | = [=](const Tmm &t1, const Xbyak::Reg64 ®_ptr, size_t offset, |
1812 | const Xbyak::Reg64 ®_stride) { |
1813 | if (offset <= INT32_MAX) { |
1814 | tileloadd(t1, ptr[reg_ptr + offset + reg_stride]); |
1815 | } else { |
1816 | safe_add(reg_ptr, offset, reg_tmp); |
1817 | tileloadd(t1, ptr[reg_ptr + reg_stride]); |
1818 | safe_sub(reg_ptr, offset, reg_tmp); |
1819 | } |
1820 | }; |
1821 | |
1822 | // normal and k-remainders path |
1823 | const bool check_kd_padding |
1824 | = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0); |
1825 | for (int icb = 0; icb < jcp.nb_ic_int; icb++) { |
1826 | Label kd_skip_compute; |
1827 | if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]); |
1828 | |
1829 | for (int kd = 0; kd < jcp.kd; kd++) { |
1830 | if (check_kd_padding) { |
1831 | dec(reg_kd); |
1832 | jl(kd_skip_compute, T_NEAR); |
1833 | push(reg_kd); |
1834 | } |
1835 | for (int kh = 0; kh < jcp.kh; kh++) { |
1836 | for (int set_idx = 0; set_idx < jcp.n_stride_sets; |
1837 | set_idx++) { // used to optimize input memory reuse in L1$ |
1838 | for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) { |
1839 | for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) { |
1840 | const size_t inp_off |
1841 | = inp_offset(icb, ohb, kd, kh, kw); |
1842 | safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)), |
1843 | reg_inp_ptr, inp_off, reg_inp_stride); |
1844 | } |
1845 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1846 | const size_t wei_off |
1847 | = wei_offset(icb, ocb, kd, kh, kw); |
1848 | safe_tileloadd(Tmm(get_wei_tensor(ocb)), |
1849 | reg_wei_ptr, wei_off, reg_wei_stride); |
1850 | for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) { |
1851 | tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)), |
1852 | Tmm(get_inp_tensor(ohb, tail)), |
1853 | Tmm(get_wei_tensor(ocb))); |
1854 | interleave_store(width, t_pad_output, |
1855 | b_pad_output, zp_3d_pad); |
1856 | } |
1857 | } |
1858 | } |
1859 | } |
1860 | } |
1861 | if (check_kd_padding) pop(reg_kd); |
1862 | } |
1863 | L(kd_skip_compute); |
1864 | } |
1865 | |
1866 | store_output(width, tail, do_store, handle_h_blk, t_pad_output, |
1867 | b_pad_output, l_pad_output, r_pad_output, is_last_oh_block, |
1868 | zp_3d_pad); |
1869 | |
1870 | add(reg_inp_ptr, get_inp_shift()); |
1871 | } |
1872 | |
// Dispatch 'compute_icb_loop' over output-height blocks. When a zero-point
// padding buffer is required and top/bottom output padding exists, emit a
// run-time jump table (indexed by the 'ohb' call argument) selecting among up
// to 'ur_h' code versions, each specialized for its own t/b padding limits.
// Otherwise emit a single generic version.
void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
        bool do_store, const int l_pad_output, const int r_pad_output,
        const bool zp_3d_pad) {
    if (jcp.req_zero_point_buffer
            && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
        const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
        const size_t height_limit = reduce_to_blocked_dims(
                jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
        const int ur_h = div_up(height_limit, oh_step_size);
        assert(6 >= ur_h); // jump table below holds at most 6 entries

        // Use a jump-table to execute the corresponding block
        Label h_blk_label[6], h_blk_end_label, jmp_table_label;
        mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
        mov(reg_tmp, jmp_table_label);
        jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
        jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen

        // emit the table of label addresses itself
        align(8);
        L(jmp_table_label);
        for (int u = 0; u < ur_h; ++u) {
            putL(h_blk_label[u]);
        }

        // Save value of global variables for the next 'h_blk' iteration
        const int local_prv_width = prv_width_;
        const int local_row_count = row_count_;
        const bool local_is_store_done = is_store_done_;
        const bool local_is_buffer_empty = is_buffer_empty_;

        // Unroll ow_block with regards to l_pad_output and r_pad_output
        int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
        int cur_b_pad = height_limit
                - reduce_to_block(oh_step_size, jcp.b_pad_output);
        for (int u = 0; u < ur_h; u++) {
            bool last = u == ur_h - 1;
            L(h_blk_label[u]);

            // restore to previous 'h_blk' state of variables, since each
            // jump-table target is generated from the same starting state
            prv_width_ = local_prv_width;
            row_count_ = local_row_count;
            is_store_done_ = local_is_store_done;
            is_buffer_empty_ = local_is_buffer_empty;
            compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
                    l_pad_output, r_pad_output, zp_3d_pad, last);
            cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
            cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
            if (!last) jmp(h_blk_end_label, T_NEAR);
        }
        L(h_blk_end_label);
    } else {
        // no zero-point height specialization needed: single generic version
        compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
                r_pad_output, zp_3d_pad);
    }
}
1928 | |
// Emit two variants of 'dispatch_icb_loop' selected at run time by comparing
// the kd_padding argument against jcp.kd: one for rows with front/back (3D)
// zero-point padding and one without. Only needed when a zero-point buffer is
// required and the shape has f/back padding; otherwise emit the single
// non-3D-padded variant.
void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
        bool do_store, const int l_pad_output, const int r_pad_output) {
    if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
        Label compute_3d_zp_label, zp_d_end_label;
        mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_kd, jcp.kd);
        jne(compute_3d_zp_label, T_NEAR);

        // Save value of global variables for next 'dispatch_icb_loop'
        const int local_prv_width = prv_width_;
        const int local_row_count = row_count_;
        const bool local_is_store_done = is_store_done_;
        const bool local_is_buffer_empty = is_buffer_empty_;
        dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);

        jmp(zp_d_end_label, T_NEAR);
        L(compute_3d_zp_label);

        // restore the generation-time state so the second variant is emitted
        // from the same starting point as the first
        prv_width_ = local_prv_width;
        row_count_ = local_row_count;
        is_store_done_ = local_is_store_done;
        is_buffer_empty_ = local_is_buffer_empty;
        dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);

        L(zp_d_end_label);
    } else
        dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
}
1957 | |
1958 | void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() { |
1959 | auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks, |
1960 | const int l_pad_output, |
1961 | const int r_pad_output) { |
1962 | int cur_l_pad_output = l_pad_output; |
1963 | int cur_r_pad_output = r_pad_output; |
1964 | int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail |
1965 | : jcp.tile_width; |
1966 | init_runtime_counters(last_owb && num_tile_blocks == 1); |
1967 | for (int owb = 0; owb < num_tile_blocks - 1; owb++) { |
1968 | dispatch_zp_3d_compute( |
1969 | jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output); |
1970 | cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width); |
1971 | cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width); |
1972 | } |
1973 | dispatch_zp_3d_compute( |
1974 | gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output); |
1975 | }; |
1976 | |
1977 | assert(jcp.nb_ow > 0); |
1978 | if (jcp.nb_ow == 1) { |
1979 | const int ow_r_pad_start |
1980 | = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output); |
1981 | compute_ow_loop_body( |
1982 | true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start); |
1983 | } else if (jcp.req_zero_point_buffer |
1984 | && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) { |
1985 | |
1986 | const size_t zp_addr_shift |
1987 | = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t); |
1988 | const int ow_step_size = jcp.ow_block; |
1989 | const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width); |
1990 | const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0 |
1991 | ? ow_blocks_per_call |
1992 | : jcp.ow_blocks % ow_blocks_per_call; |
1993 | const int width_limit = reduce_to_blocked_dims( |
1994 | jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output); |
1995 | const int ur_w = div_up(width_limit, ow_step_size); |
1996 | assert(6 >= ur_w); |
1997 | // Use a jump-table to execute the corresponding block |
1998 | Label w_blk_label[6], w_blk_end_label, jmp_table_label; |
1999 | mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]); |
2000 | mov(reg_tmp, jmp_table_label); |
2001 | jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]); |
2002 | jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen |
2003 | |
2004 | align(8); |
2005 | L(jmp_table_label); |
2006 | for (int u = 0; u < ur_w; ++u) { |
2007 | putL(w_blk_label[u]); |
2008 | } |
2009 | |
2010 | // Unroll ow_block with regards to l_pad_output and r_pad_output |
2011 | int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output); |
2012 | int cur_r_pad |
2013 | = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output); |
2014 | int zp_offset = 0; |
2015 | for (int u = 0; u < ur_w; u++) { |
2016 | const bool last = u == ur_w - 1; |
2017 | L(w_blk_label[u]); |
2018 | if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift); |
2019 | compute_ow_loop_body(last, |
2020 | last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad, |
2021 | cur_r_pad); |
2022 | zp_offset += accum_with_upper_bound( |
2023 | ow_step_size, cur_l_pad, cur_r_pad); |
2024 | cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size); |
2025 | cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size); |
2026 | if (!last) jmp(w_blk_end_label, T_NEAR); |
2027 | } |
2028 | L(w_blk_end_label); |
2029 | |
2030 | } else { |
2031 | assert(jcp.oh_per_tile == 1); |
2032 | Label label_done; |
2033 | int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width); |
2034 | int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call; |
2035 | if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0) |
2036 | last_owb_tile_blocks = ow_blocks_per_call; |
2037 | if (last_owb_tile_blocks > 0) { |
2038 | Label label_not_last_owb; |
2039 | mov(reg_tmp, ptr[param1 + GET_OFF(owb)]); |
2040 | cmp(reg_tmp, jcp.nb_ow - 1); |
2041 | jne(label_not_last_owb, T_NEAR); |
2042 | |
2043 | compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow); |
2044 | |
2045 | jmp(label_done, T_NEAR); |
2046 | |
2047 | L(label_not_last_owb); |
2048 | } |
2049 | compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow); |
2050 | |
2051 | L(label_done); |
2052 | } |
2053 | } |
2054 | |
// Kernel entry point generator: load the call arguments from 'param1',
// initialize the input/weights stride registers and the output-channel tail
// mask, then emit the main output-width loop.
void jit_avx512_core_amx_fwd_kernel_t::generate() {
    preamble();

    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
    if (jcp.req_zero_point_buffer)
        mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);

    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    // tileloadd stride between consecutive rows of the input tiles; depends
    // on how the pbuffer was laid out (relo / strided / dense)
    const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
            : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
    const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
    const int wei_stride = jcp.oc_block * jcp.typesize_acc;
    mov(reg_inp_stride, inp_stride);
    mov(reg_wei_stride, wei_stride);

    if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
        // Use mask 0xF by default for all output data and post-ops
        // loads / stores with block index
        // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
        // TODO: use masked loads / stores for the last occ only
        int current_block_size = jcp.oc_block;
        int mask = (1 << current_block_size) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
        Xbyak::Label mask_is_set;
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(mask_is_set, T_NEAR);
        // Reset the mask to cover only the remaining (unpadded) channels
        current_block_size = jcp.oc_without_padding % jcp.oc_block;
        mask = (1 << current_block_size) - 1;
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);

        L(mask_is_set);
    }
    compute_ow_loop();

    postamble();

    // eltwise injector's lookup table is placed after the kernel body
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
2103 | |
2104 | void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) { |
2105 | const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4; |
2106 | // Input tile dimensions |
2107 | const int a_col = jcp.is_relo ? jcp.ic_block_int |
2108 | : jcp.ic_block_int_np * jcp.kw_per_tile; |
2109 | // Weights tile dimensions |
2110 | const int b_col = jcp.oc_block * vnni_width; |
2111 | const int b_row = a_col / vnni_width; |
2112 | // Accumulator tile dimensions |
2113 | const int c_col = 16; |
2114 | |
2115 | for (size_t i = 0; i < 64; i++) |
2116 | tcfg_buff[i] = 0; |
2117 | |
2118 | // Weights (W_BASE) Tensor Tiles |
2119 | for (int i = 0; i < jcp.nb_oc_blocking; i++) |
2120 | tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i), |
2121 | b_row, b_col * jcp.typesize_in); |
2122 | |
2123 | // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles |
2124 | for (int h = 0; h < jcp.nb_oh_blocking; h++) { |
2125 | tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h), |
2126 | jcp.tile_width, a_col * jcp.typesize_in); |
2127 | for (int i = 0; i < jcp.nb_oc_blocking; i++) |
2128 | tc_configure_tile((palette_config_t *)tcfg_buff, |
2129 | get_out_tensor(h, i), jcp.tile_width, |
2130 | c_col * jcp.typesize_acc); |
2131 | } |
2132 | if (jcp.tile_tail != 0) { |
2133 | assert(jcp.nb_oh_blocking == 1); |
2134 | assert(jcp.oh_per_tile == 1); |
2135 | assert(jcp.ow > jcp.tile_width); |
2136 | tc_configure_tile((palette_config_t *)tcfg_buff, |
2137 | get_inp_tensor(0, true), jcp.tile_tail, |
2138 | a_col * jcp.typesize_in); |
2139 | for (int i = 0; i < jcp.nb_oc_blocking; i++) |
2140 | tc_configure_tile((palette_config_t *)tcfg_buff, |
2141 | get_out_tensor(0, i, true), jcp.tile_tail, |
2142 | c_col * jcp.typesize_acc); |
2143 | } |
2144 | |
2145 | ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette(); |
2146 | } |
2147 | |
2148 | void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) { |
2149 | |
2150 | constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]); |
2151 | // set default values |
2152 | for (int i = 0; i < size; i++) |
2153 | jcp.h_blk_limits[i] = jcp.oh; |
2154 | |
2155 | const bool calculate_oh_limits |
2156 | = jcp.t_pad_output > 0 || jcp.b_pad_output > 0; |
2157 | if (jcp.req_zero_point_buffer && calculate_oh_limits) { |
2158 | |
2159 | int limit_idx = 0; |
2160 | const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile; |
2161 | |
2162 | // full t_pad output block |
2163 | const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size); |
2164 | if (jcp.t_pad_output >= oh_step_size) { |
2165 | jcp.h_blk_limits[limit_idx++] = t_pad_blk_end; |
2166 | } |
2167 | // t_pad output overlap with no padding |
2168 | const int t_pad_shift = jcp.t_pad_output % oh_step_size; |
2169 | if (t_pad_shift != 0) { |
2170 | jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift; |
2171 | } |
2172 | const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size); |
2173 | const int oh_blk_tail = jcp.oh % oh_step_size; |
2174 | const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail); |
2175 | const int b_pad_start |
2176 | = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output); |
2177 | const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size); |
2178 | // middle block without padding |
2179 | const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk); |
2180 | if (mid_blk >= oh_step_size) { |
2181 | jcp.h_blk_limits[limit_idx++] = b_pad_blk_start; |
2182 | } |
2183 | // no padding with b_pad overlap |
2184 | const int b_pad_shift = b_pad_no_tail % oh_step_size; |
2185 | if (b_pad_shift != 0) { |
2186 | jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size); |
2187 | } |
2188 | // full b_pad output block |
2189 | if (b_pad_no_tail >= oh_step_size) { |
2190 | jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail; |
2191 | } |
2192 | // b_pad tail block does not require a limit |
2193 | } |
2194 | } |
2195 | |
2196 | void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) { |
2197 | |
2198 | jcp.l_pad_blk = 0; |
2199 | jcp.no_pad_w_blk = 0; |
2200 | jcp.r_pad_blk = 0; |
2201 | |
2202 | const bool calculate_ow_limits |
2203 | = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0); |
2204 | if (jcp.req_zero_point_buffer && calculate_ow_limits) { |
2205 | const int ow_step_size = jcp.ow_block; |
2206 | |
2207 | // l_pad |
2208 | const int l_pad_limit |
2209 | = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0) |
2210 | + (jcp.l_pad_output % ow_step_size); |
2211 | const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size); |
2212 | jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size); |
2213 | |
2214 | // middle (area without padding) |
2215 | const int no_pad_area |
2216 | = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output); |
2217 | jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0; |
2218 | |
2219 | // r_pad |
2220 | const int no_pad_area_shift = no_pad_area % ow_step_size; |
2221 | const int r_pad_area_overlap |
2222 | = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift; |
2223 | const int r_pad_area |
2224 | = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap); |
2225 | const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0) |
2226 | + (r_pad_area % ow_step_size); |
2227 | jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0) |
2228 | + div_up(r_pad_limit, ow_step_size); |
2229 | } |
2230 | } |
2231 | |
2232 | status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, |
2233 | const convolution_desc_t &cd, memory_desc_t &src_md, |
2234 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
2235 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
2236 | using namespace prop_kind; |
2237 | |
2238 | const memory_desc_wrapper src_d(&src_md); |
2239 | const memory_desc_wrapper weights_d(&weights_md); |
2240 | const memory_desc_wrapper dst_d(&dst_md); |
2241 | const memory_desc_wrapper bias_d(&bias_md); |
2242 | |
2243 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
2244 | int ndims = src_d.ndims(); |
2245 | bool is_1d = ndims == 3; |
2246 | bool is_3d = ndims == 5; |
2247 | |
2248 | const bool is_bf16_convolution |
2249 | = everyone_is(true, src_d.data_type() == data_type::bf16, |
2250 | weights_d.data_type() == data_type::bf16, |
2251 | one_of(dst_d.data_type(), data_type::bf16, data_type::f32)); |
2252 | const bool is_int8_convolution = everyone_is(true, |
2253 | (src_d.data_type() == data_type::u8 |
2254 | || src_d.data_type() == data_type::s8), |
2255 | weights_d.data_type() == data_type::s8, |
2256 | one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
2257 | data_type::s8, data_type::u8, data_type::bf16)); |
2258 | |
2259 | bool supported = mayiuse(avx512_core_amx) |
2260 | && (is_bf16_convolution || is_int8_convolution); |
2261 | if (!supported) return status::unimplemented; |
2262 | |
2263 | jcp = zero<decltype(jcp)>(); |
2264 | jcp.isa = avx512_core_amx; |
2265 | jcp.ndims = ndims; |
2266 | jcp.prop_kind = cd.prop_kind; |
2267 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
2268 | |
2269 | jcp.mb = src_d.dims()[0]; |
2270 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
2271 | jcp.oc_without_padding = jcp.oc; |
2272 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
2273 | jcp.ic_without_padding = jcp.ic; |
2274 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
2275 | jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1; |
2276 | jcp.iw = src_d.dims()[ndims - 1]; |
2277 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
2278 | jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1; |
2279 | jcp.ow = dst_d.dims()[ndims - 1]; |
2280 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
2281 | jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1; |
2282 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
2283 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
2284 | jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0; |
2285 | jcp.l_pad = cd.padding[0][ndims - 3]; |
2286 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
2287 | jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1; |
2288 | jcp.stride_w = cd.strides[ndims - 3]; |
2289 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
2290 | |
2291 | jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0; |
2292 | jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0; |
2293 | jcp.dilate_w = cd.dilates[ndims - 3]; |
2294 | |
2295 | const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1; |
2296 | const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1; |
2297 | const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1; |
2298 | jcp.back_pad = calculate_end_padding( |
2299 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd); |
2300 | jcp.b_pad = calculate_end_padding( |
2301 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh); |
2302 | jcp.r_pad = calculate_end_padding( |
2303 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw); |
2304 | if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh |
2305 | || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd |
2306 | || jcp.back_pad >= gen_kd) |
2307 | return status::unimplemented; |
2308 | |
2309 | const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits |
2310 | if (jcp.l_pad > max_pad || jcp.r_pad > max_pad) |
2311 | return status::unimplemented; // TODO: relax this restriction |
2312 | |
2313 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
2314 | jcp.dst_dt = cd.dst_desc.data_type; |
2315 | jcp.src_dt = cd.src_desc.data_type; |
2316 | jcp.wei_dt = cd.weights_desc.data_type; |
2317 | |
2318 | jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); |
2319 | |
2320 | if (jcp.is_depthwise) |
2321 | return status::unimplemented; // TODO: add support of DW convolution |
2322 | |
2323 | // Dispatch small shapes to VNNI for better performance |
2324 | const auto is_real_3d = (jcp.ndims == 5 |
2325 | && (jcp.id > 1 || jcp.od > 1 || jcp.kd > 1 || jcp.dilate_d > 0)); |
2326 | const auto is_supported_small_ic = jcp.ic <= 4 && !is_real_3d; |
2327 | const auto is_small_shape = (jcp.od * jcp.oh * jcp.ow <= 4) && jcp.ic <= 512 |
2328 | && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= static_cast<int32_t>( |
2329 | platform::get_per_core_cache_size(1)); |
2330 | const auto is_3d_small_ic = is_real_3d && jcp.ic * jcp.oc <= 32 |
2331 | && jcp.od >= 128 && jcp.oh >= 128 && jcp.ow >= 128; |
2332 | if ((is_small_shape || is_3d_small_ic) && !is_supported_small_ic) |
2333 | return status::unimplemented; |
2334 | |
2335 | const auto zp = attr.zero_points_; |
2336 | jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); |
2337 | jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); |
2338 | jcp.zp_src_is_common = zp.common( |
2339 | DNNL_ARG_SRC); // otherwise, it's per-channel (not supported) |
2340 | |
2341 | if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common) |
2342 | || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point, |
2343 | is_int8_convolution)) |
2344 | return status::unimplemented; |
2345 | |
2346 | // Calculate zero-point padding values outside of the main JIT-kernel |
2347 | // and store the results in an auxiliary buffer. |
2348 | jcp.req_zero_point_buffer = jcp.src_zero_point |
2349 | && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0 |
2350 | || jcp.f_pad > 0 || jcp.back_pad > 0); |
2351 | |
2352 | format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c, |
2353 | format_tag::nChw16c, format_tag::nCdhw16c); |
2354 | format_tag_t dat_tag_nspc = utils::pick( |
2355 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
2356 | // To toggle the default data layout for BF16 between nChw16c and nhwc, |
2357 | // swap the following two variable definitions. Current choice: nhwc. |
2358 | |
2359 | // Clang-tidy change - if it was intentional please revert it and |
2360 | // put `NOLINTNEXTLINE` to suppress the warning. |
2361 | format_tag_t dat_tag_opt = dat_tag_nspc; |
2362 | format_tag_t dat_tag_alt |
2363 | = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc; |
2364 | |
2365 | if (src_d.format_kind() == format_kind::any) { |
2366 | CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt)); |
2367 | jcp.src_tag = dat_tag_opt; |
2368 | } else |
2369 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt); |
2370 | |
2371 | if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt)) |
2372 | return status::unimplemented; |
2373 | |
2374 | jcp.is_nspc = jcp.src_tag == dat_tag_nspc; |
2375 | assert(IMPLICATION(is_int8_convolution, jcp.is_nspc)); |
2376 | |
2377 | // TODO: remove all support for nChw16c from this implementation |
2378 | if (!jcp.is_nspc) return status::unimplemented; |
2379 | |
2380 | if (dst_d.format_kind() == format_kind::any) { |
2381 | CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag)); |
2382 | jcp.dst_tag = jcp.src_tag; |
2383 | } else |
2384 | jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag); |
2385 | |
2386 | if (jcp.dst_tag != jcp.src_tag) return status::unimplemented; |
2387 | |
2388 | if (jcp.with_bias && bias_d.format_kind() == format_kind::any) |
2389 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
2390 | |
2391 | jcp.nthr = nthreads; |
2392 | |
2393 | jcp.ic_block = 16; |
2394 | jcp.oc_block = 16; |
2395 | |
2396 | const auto ic_unrounded = jcp.ic; |
2397 | if (jcp.ngroups == 1) { |
2398 | jcp.oc = rnd_up(jcp.oc, jcp.oc_block); |
2399 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
2400 | } |
2401 | bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0; |
2402 | if (!args_ok) return status::unimplemented; |
2403 | |
2404 | const int vnni_width = is_bf16_convolution ? 2 : 4; |
2405 | jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8 |
2406 | |
2407 | // fallback to non-amx impl when accumulation is too small |
2408 | const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw; |
2409 | const bool is_tiny_k = total_k < jcp.ic_block_int / 2; |
2410 | if (is_tiny_k) return status::unimplemented; |
2411 | |
2412 | // small-ic parameters |
2413 | jcp.ic_block_int_np = jcp.is_nspc |
2414 | ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding) |
2415 | : jcp.ic_block_int; |
2416 | bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int; |
2417 | |
2418 | // reduced lowering |
2419 | const bool is_trivial_3d = everyone_is(1, jcp.id, jcp.od, jcp.kd); |
2420 | jcp.is_relo = is_trivial_3d |
2421 | && is_small_ic |
2422 | // no trivial cases |
2423 | && 1 < jcp.kh * jcp.kw |
2424 | // reduction dimension size (heuristic) |
2425 | && IMPLICATION(is_int8_convolution, |
2426 | 12 * 3 * 3 < ic_unrounded * jcp.kh * jcp.kw) |
2427 | // required for use of VPERMB instruction in weights copy kernel |
2428 | && IMPLICATION(is_int8_convolution, |
2429 | cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)) |
2430 | // no dilation or excessive stride along w-direction |
2431 | && everyone_is(0, jcp.dilate_h, jcp.dilate_w) |
2432 | // no dilation or excessive stride along h-direction |
2433 | && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw; |
2434 | |
2435 | // Dispatch specific small ic shapes to VNNI for better performance |
2436 | const auto is_2d_small_ic = jcp.ndims == 4 |
2437 | && (ic_unrounded < 12 |
2438 | || (ic_unrounded >= 12 && ic_unrounded < 16 |
2439 | && jcp.ow * jcp.oh < 768 * 768)); |
2440 | if (is_int8_convolution && is_2d_small_ic && !jcp.is_relo) |
2441 | return status::unimplemented; |
2442 | |
2443 | jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np; |
2444 | |
2445 | if (!jcp.is_relo) { |
2446 | jcp.ic_block_int_np = is_bf16_convolution |
2447 | ? jcp.ic_block_int |
2448 | : rnd_up(jcp.ic_block_int_np, vnni_width); |
2449 | is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int; |
2450 | } |
2451 | // k-remainders |
2452 | jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0 |
2453 | && jcp.stride_w <= jcp.kw // TODO: relax this restriction |
2454 | && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int |
2455 | ? jcp.kw |
2456 | : 1; |
2457 | jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile); |
2458 | jcp.n_stride_sets |
2459 | = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1; |
2460 | jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile; |
2461 | |
2462 | if (attr.set_default_formats(&dst_md) != status::success) |
2463 | return status::unimplemented; |
2464 | |
2465 | const auto &p = attr.post_ops_; |
2466 | |
2467 | const int sum_ind = p.find(primitive_kind::sum); |
2468 | jcp.with_sum = sum_ind != -1; |
2469 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
2470 | jcp.with_eltwise = eltwise_ind != -1; |
2471 | const int binary_ind = p.find(primitive_kind::binary); |
2472 | jcp.with_binary = binary_ind != -1; |
2473 | jcp.sum_dt = p.get_sum_dt(jcp.dst_dt); |
2474 | |
2475 | jcp.post_ops = p; |
2476 | |
2477 | using namespace injector; |
2478 | const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16); |
2479 | const bool sum_requires_scale_one = sum_at_pos_0_only; |
2480 | const bool sum_requires_zp_zero = sum_at_pos_0_only; |
2481 | const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, |
2482 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
2483 | sum_requires_zp_zero}); |
2484 | if (!post_ops_ok_) return status::unimplemented; |
2485 | |
2486 | auto set_or_check_wei_format = [&]() { |
2487 | using namespace format_tag; |
2488 | using namespace memory_extra_flags; |
2489 | format_tag_t wei_tag; |
2490 | wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o, |
2491 | gOwi16o, Owhi16o, gOwhi16o, Odwhi16o, gOdwhi16o) |
2492 | : is_bf16_convolution |
2493 | ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i, |
2494 | gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i, |
2495 | OIdhw16i16o2i, gOIdhw16i16o2i) |
2496 | : is_small_ic ? pick(with_groups + 2 * (ndims - 3), |
2497 | OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i, |
2498 | OdhwI16o4i, gOdhwI16o4i) |
2499 | : pick(with_groups + 2 * (ndims - 3), |
2500 | OIw16i16o4i, gOIw16i16o4i, |
2501 | OIhw16i16o4i, gOIhw16i16o4i, |
2502 | OIdhw16i16o4i, gOIdhw16i16o4i); |
2503 | |
2504 | memory_desc_t want_wei_md = weights_md; |
2505 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
2506 | |
2507 | if (jcp.src_zero_point) { |
2508 | want_wei_md.extra.flags |= compensation_conv_asymmetric_src; |
2509 | want_wei_md.extra.asymm_compensation_mask = (1 << 0) |
2510 | + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); |
2511 | } |
2512 | if (weights_md.format_kind == format_kind::any) { |
2513 | weights_md = want_wei_md; |
2514 | return true; |
2515 | } |
2516 | return weights_md == want_wei_md; |
2517 | }; |
2518 | |
2519 | if (!set_or_check_wei_format()) return status::unimplemented; |
2520 | |
2521 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
2522 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
2523 | jcp.typesize_bia |
2524 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
2525 | jcp.typesize_acc = sizeof(int32_t); |
2526 | |
2527 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
2528 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
2529 | jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int); |
2530 | |
2531 | jcp.nb_oc_blocking_thr_chunk = 1; |
2532 | |
2533 | const int target_palette = amx::get_target_palette(); |
2534 | jcp.max_tiles = amx::get_max_tiles(target_palette); |
2535 | jcp.full_tile_width = amx::get_max_rows(target_palette); |
2536 | if (jcp.max_tiles != 8 || jcp.full_tile_width != 16) |
2537 | return status::unimplemented; |
2538 | |
2539 | // Pack n rows per tile, such that: |
2540 | // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width |
2541 | auto calculate_tile_width = [&](int n) { |
2542 | assert(n > 0); |
2543 | return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1); |
2544 | }; |
2545 | const bool ok_to_pack_tile = !jcp.is_relo |
2546 | && (utils::everyone_is(1, jcp.kh, jcp.kw) |
2547 | || utils::everyone_is(1, jcp.stride_h, jcp.stride_w)); |
2548 | const int max_oh_per_tile |
2549 | = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1); |
2550 | jcp.oh_per_tile = ok_to_pack_tile |
2551 | ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile)) |
2552 | : 1; |
2553 | jcp.tile_width = nstl::min<int>( |
2554 | jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile)); |
2555 | jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width); |
2556 | |
2557 | // Prefer to use a single tile width when possible |
2558 | // (eg ow28 => 2 tiles of 14 vs 1 of 16 and 1 of 12) |
2559 | if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0) |
2560 | jcp.tile_width = jcp.ow / jcp.ow_blocks; |
2561 | jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0; |
2562 | |
2563 | jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1; |
2564 | jcp.nb_ic_blocking = 1; |
2565 | jcp.nb_oh_blocking |
2566 | = utils::everyone_is(true, jcp.tile_tail == 0, |
2567 | // requirement for interleave stores |
2568 | IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0), |
2569 | // requirement for small spatial |
2570 | utils::div_up(jcp.oh, jcp.oh_per_tile) > 1, |
2571 | // choose maximal pbuffer overlap for reduced lowering |
2572 | !jcp.is_relo) |
2573 | ? 2 |
2574 | : 1; |
2575 | |
2576 | // TODO: tune oh blocking |
2577 | const int oh_blk_size_param = jcp.is_relo ? 1 : 10; |
2578 | const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile; |
2579 | const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size); |
2580 | jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size); |
2581 | // Here ihp means the input buffer height including padding (ie the number |
2582 | // of input rows required for computation of jcp.oh_blk_size output rows. |
2583 | // If an input row doesn't participate in the computation of any output row, |
2584 | // it isn't copied to the buffer at all (eg jcp.stride_h > gen_kh). |
2585 | jcp.ihp = jcp.is_relo |
2586 | ? jcp.oh_blk_size |
2587 | : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh; |
2588 | |
2589 | // TODO: tune ow blocking |
2590 | const int ow_blocks_per_call = jcp.is_relo ? 10 : 2; |
2591 | jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call); |
2592 | jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block); |
2593 | // iwp includes all width elements that are really used in calculation |
2594 | // including left and right zero padding |
2595 | const bool are_sets_interleaved |
2596 | = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1); |
2597 | jcp.iwp = are_sets_interleaved |
2598 | ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw |
2599 | : jcp.ow_block * jcp.kw; |
2600 | |
2601 | // Number of ops per tile store |
2602 | int ops_tile_store = jcp.tile_width; |
2603 | // Number of ops per accumulation tile |
2604 | int avaliable_ops = jcp.is_relo |
2605 | ? utils::div_up(jcp.nreduce, jcp.ic_block_int) |
2606 | : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile); |
2607 | // Number of vectors to store per tile operation |
2608 | // NOTE: set to zero to turn off interleave store (mostly for debugging) |
2609 | jcp.per_one_pstore = utils::div_up(ops_tile_store, avaliable_ops); |
2610 | |
2611 | if (jcp.is_relo) { |
2612 | jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh |
2613 | * jcp.ic_block_int_np |
2614 | // pbuffer pointer shifts each oh step for reduced-lowering |
2615 | + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np |
2616 | // extra $line due to pbuffer writing full Zmm |
2617 | + jcp.ic_block_int; |
2618 | } else { |
2619 | jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd |
2620 | * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np |
2621 | // extra $line due to pbuffer writing full Zmm |
2622 | + jcp.ic_block_int); |
2623 | } |
2624 | jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc |
2625 | * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024); |
2626 | jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking |
2627 | * jcp.full_tile_width * jcp.oc_block; |
2628 | |
2629 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
2630 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
2631 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
2632 | const int wei_mask_per_oc = 1 << (int)with_groups; |
2633 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
2634 | jcp.dst_scale = !dst_scales.has_default_values(); |
2635 | |
2636 | // only common src & dst scales are supported |
2637 | // only common and per-oc-channel weight scales are supported |
2638 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
2639 | && everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
2640 | if (!scales_ok) return status::unimplemented; |
2641 | |
2642 | // Note: currently unsupported, results in seg-fault |
2643 | const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w)); |
2644 | if (!jcp.is_relo && (l_pad_output > jcp.ow_block)) |
2645 | return status::unimplemented; |
2646 | |
2647 | // Relevant to 'zero_point padding buffer' (pbuff) jit kernel |
2648 | if (jcp.req_zero_point_buffer) { |
2649 | auto calculate_output_padding_dims = [=](int o_dim, int s_pad, |
2650 | int e_pad, |
2651 | int &s_pad_output, |
2652 | int &e_pad_output, |
2653 | bool &o_mid, int &o_pad, |
2654 | int stride, |
2655 | bool req_mid_area) { |
2656 | s_pad_output = nstl::min(o_dim, div_up(s_pad, stride)); |
2657 | e_pad_output = nstl::min(o_dim, div_up(e_pad, stride)); |
2658 | o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area; |
2659 | o_pad = nstl::min(o_dim, |
2660 | nstl::max(1, s_pad_output + e_pad_output + (int)o_mid)); |
2661 | }; |
2662 | |
2663 | const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0) |
2664 | && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0 |
2665 | || jcp.back_pad > 0); |
2666 | const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0) |
2667 | && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0 |
2668 | || jcp.back_pad > 0); |
2669 | const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0) |
2670 | && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 |
2671 | || jcp.t_pad > 0); |
2672 | calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad, |
2673 | jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad, |
2674 | jcp.stride_w, mid_w_area); |
2675 | calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad, |
2676 | jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad, |
2677 | jcp.stride_h, mid_h_area); |
2678 | calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad, |
2679 | jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad, |
2680 | jcp.stride_d, mid_d_area); |
2681 | jcp.zp_pbuff_size |
2682 | = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups; |
2683 | |
2684 | // compute zero-point padding kernel outside of the main parallel |
2685 | // region when threads are more likely to parallelize work across mb |
2686 | // within the convolution compute block. |
2687 | jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d; |
2688 | |
2689 | const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2); |
2690 | if (!params_ok) { return status::unimplemented; } |
2691 | } |
2692 | |
2693 | // Set default parameters for driver code, but mostly required for |
2694 | // 'zero_point padding buffer' (pbuff) accumulation over output tensor |
2695 | set_oh_blk_limits(jcp); |
2696 | set_ow_blk_limits(jcp); |
2697 | |
2698 | return status::success; |
2699 | } |
2700 | |
2701 | status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad( |
2702 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
2703 | const primitive_attr_t &attr) { |
2704 | |
2705 | size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size; |
2706 | scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in); |
2707 | if (jcp.is_relo) { |
2708 | scratchpad.book( |
2709 | key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in); |
2710 | } |
2711 | |
2712 | size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size; |
2713 | scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc); |
2714 | if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) { |
2715 | assert(jcp.ngroups == 1); |
2716 | scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia); |
2717 | } |
2718 | scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline |
2719 | if (jcp.req_zero_point_buffer) { |
2720 | const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr; |
2721 | scratchpad.book(key_conv_zero_point_pad, |
2722 | (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t)); |
2723 | if (!jcp.zp_pbuff_outer_compute) { |
2724 | const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
2725 | scratchpad.book<bool>(key_conv_zero_point_flag, |
2726 | (size_t)jcp.nthr * oc_chunks * jcp.ngroups); |
2727 | } |
2728 | } |
2729 | |
2730 | book_precomputed_scales( |
2731 | scratchpad, attr.scales_, jcp.ngroups * jcp.oc_without_padding); |
2732 | |
2733 | // Keep scratchpad memory footprint under control |
2734 | const size_t L2_size_per_core = platform::get_per_core_cache_size(2); |
2735 | const size_t L3_size_per_core = platform::get_per_core_cache_size(3); |
2736 | const size_t max_scratchpad_size |
2737 | = jcp.nthr * (L2_size_per_core + L3_size_per_core); |
2738 | // TODO: tune this relationship as needed |
2739 | if (scratchpad.size() > max_scratchpad_size) return status::unimplemented; |
2740 | return status::success; |
2741 | } |
2742 | |
// Emits code that copies one spatial row-block of the user's ddst tensor
// into the padded pbuffer, zero-filling every padding region:
//  - t_overflow/b_overflow rows above/below the source data,
//  - reg_lov/reg_rov columns left/right of it,
//  - "dilation-by-stride" gap elements when stride_w/stride_h > 1.
// When 'is_masked' is set, loads use ktail_mask to handle an oc tail.
// Expects the aux in/out pointers and reg_khp/reg_kwp/reg_lov/reg_rov to be
// initialized by the caller (see kd_loop()/generate()).
void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
        const bool is_masked) {
    assert(jcp.is_nspc && "no support for nChw16c in this copy kernel" );

    const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
    // source (user ddst) steps: one w-element spans all groups' channels
    const int inp_w_step
            = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
    const int inp_h_step = jcp.ow * inp_w_step;
    // destination (padded buffer) steps: one oc_block_int chunk per element
    const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
    const int out_h_step = jcp.owp * out_w_step;

    // store a full zero vector at [tmp_out_ptr + offset]
    auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
        // no mask as output is a padded buffer
        if (is_bf16)
            vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
        else
            vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
    };

    // copy one vector from input to output; the load is zero-masked for an
    // oc tail when 'is_masked' is set
    auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
                           int out_off) {
        Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
        Zmm zmm_stor = zmm_tmp; // no mask as output is padded buffer
        if (is_bf16) {
            vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
            vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
        } else {
            vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
            vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
        }
    };

    { // Handle Top Overflow
        Label label_tov_loop, label_tov_skip;
        // number of zero-padded rows above src buffer to copy
        mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
        test(reg_tov, reg_tov);
        jz(label_tov_skip, T_NEAR);
        L(label_tov_loop);
        {
            // zero a full padded row (jcp.owp elements)
            for (int ow = 0; ow < jcp.owp; ow++) {
                const int offset = ow * out_w_step;
                zero_it(reg_ptr_aux_out, offset);
            }
            add(reg_ptr_aux_out, out_h_step);
            dec(reg_tov);
            jnz(label_tov_loop, T_NEAR);
        }
        L(label_tov_skip);
    }

    // Handle Middle Loop: the reg_khp rows that contain real source data
    Label label_khp_loop, label_khp_skip;
    test(reg_khp, reg_khp);
    jz(label_khp_skip, T_NEAR);
    mov(reg_cnt_khp, reg_khp);
    L(label_khp_loop);
    {
        Label label_lov, label_lov_skip;
        Label label_kwp, label_kwp_skip;
        Label label_rov, label_rov_skip;
        // dispatch to the first non-empty section of the row
        test(reg_lov, reg_lov);
        jnz(label_lov, T_NEAR);
        test(reg_kwp, reg_kwp);
        jnz(label_kwp, T_NEAR);
        test(reg_rov, reg_rov);
        jnz(label_rov, T_NEAR);

        test(reg_lov, reg_lov);
        jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
        L(label_lov); // Handle Left Overflow
        {
            // zero the reg_lov leading elements of the row
            Label label_lov_loop;
            mov(reg_cnt_tmp, reg_lov);
            L(label_lov_loop);
            {
                zero_it(reg_ptr_aux_out, 0);
                add(reg_ptr_aux_out, out_w_step);
                dec(reg_cnt_tmp);
                jnz(label_lov_loop, T_NEAR);
            }
        }
        L(label_lov_skip);

        test(reg_kwp, reg_kwp);
        jz(label_kwp_skip, T_NEAR);
        L(label_kwp); // Handle Center Loop
        {
            // copy reg_kwp source elements (plus stride gaps, if any)
            Label label_kwp_loop;
            mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
            mov(reg_cnt_tmp, reg_kwp);
            L(label_kwp_loop);
            {
                copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
                add(reg_ptr_aux_out, out_w_step);
                add(reg_ptr_aux_inp_w, inp_w_step);
                dec(reg_cnt_tmp);

                if (jcp.stride_w > 1) {
                    // counter exhausted right after a copy: no trailing gaps
                    jz(label_kwp_skip, T_NEAR);
                    // Handle Dilation-by-Stride: zero the (stride_w - 1)
                    // gap elements between consecutive copied values; the
                    // gaps also count against reg_cnt_tmp
                    for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
                        const int offset = sw * out_w_step;
                        zero_it(reg_ptr_aux_out, offset);
                    }
                    add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
                    if (jcp.stride_w == 2)
                        dec(reg_cnt_tmp);
                    else
                        sub(reg_cnt_tmp, jcp.stride_w - 1);
                    jmp(label_kwp_loop, T_NEAR);
                } else {
                    jnz(label_kwp_loop, T_NEAR);
                }
            }
        }
        L(label_kwp_skip);

        test(reg_rov, reg_rov);
        jz(label_rov_skip, T_NEAR);
        L(label_rov); // Handle Right Overflow
        {
            // zero the reg_rov trailing elements of the row
            Label label_rov_loop;
            mov(reg_cnt_tmp, reg_rov);
            L(label_rov_loop);
            {
                zero_it(reg_ptr_aux_out, 0);
                add(reg_ptr_aux_out, out_w_step);
                dec(reg_cnt_tmp);
                jnz(label_rov_loop, T_NEAR);
            }
        }
        L(label_rov_skip);

        add(reg_ptr_aux_inp_h, inp_h_step);
        dec(reg_cnt_khp);

        if (jcp.stride_h > 1) {
            // counter exhausted right after a row: no trailing gap rows
            jz(label_khp_skip, T_NEAR);
            // Handle Dilation-by-Stride: zero the (stride_h - 1) gap rows
            // between consecutive copied rows; the gap rows also count
            // against reg_cnt_khp
            for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
                for (int ow = 0; ow < jcp.owp; ow++) {
                    const int offset = sh * out_h_step + ow * out_w_step;
                    zero_it(reg_ptr_aux_out, offset);
                }
            }
            add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
            if (jcp.stride_h == 2)
                dec(reg_cnt_khp);
            else
                sub(reg_cnt_khp, jcp.stride_h - 1);
            jmp(label_khp_loop, T_NEAR);
        } else {
            jnz(label_khp_loop, T_NEAR);
        }
    }
    L(label_khp_skip);

    { // Handle Bottom Overflow
        Label label_bov_loop, label_bov_skip;

        // number of zero-padded rows below src buffer to copy
        mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
        test(reg_bov, reg_bov);
        jz(label_bov_skip, T_NEAR);
        L(label_bov_loop);
        {
            // zero a full padded row (jcp.owp elements)
            for (int ow = 0; ow < jcp.owp; ow++) {
                const int offset = ow * out_w_step;
                zero_it(reg_ptr_aux_out, offset);
            }
            add(reg_ptr_aux_out, out_h_step);
            dec(reg_bov);
            jnz(label_bov_loop, T_NEAR);
        }
        L(label_bov_skip);
    }
}
2921 | |
// Runs copy_row() once for 2d shapes, or 'kd_padding' times for 3d shapes,
// advancing the aux input/output pointers between depth iterations.
// 'is_masked' is forwarded to copy_row() to handle an oc tail.
void jit_avx512_core_amx_bwd_data_copy_kernel_t::kd_loop(bool is_masked) {

    Xbyak::Label kd_label, no_kd_label;

    const bool is_3d = jcp.ndims == 5;

    // (re-)initialize the aux pointers from the per-call base pointers
    mov(reg_ptr_aux_out, reg_ptr_out);
    mov(reg_ptr_aux_inp_h, reg_ptr_inp);

    if (is_3d) {
        // skip the depth loop entirely when kd_padding <= 0
        mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_kd, 0);
        jle(no_kd_label, T_NEAR);
        L(kd_label);
        // copy_row() advances the aux pointers, so preserve them across
        // each depth iteration
        push(reg_ptr_aux_out);
        push(reg_ptr_aux_inp_h);
    }

    copy_row(is_masked);

    if (is_3d) {
        // per-depth-step shifts: the input walks backwards through the user
        // buffer by (dilate_d + 1) spatial planes, while the output advances
        // forward by one padded [ohp][owp] plane
        const size_t inp_d_offset = static_cast<size_t>(jcp.typesize_in)
                * (jcp.dilate_d + 1) * jcp.oh * jcp.ow * jcp.ngroups
                * jcp.oc_without_padding;
        const size_t out_d_offset = static_cast<size_t>(jcp.typesize_in)
                * jcp.ohp * jcp.owp * jcp.oc_block_int;
        pop(reg_ptr_aux_inp_h);
        pop(reg_ptr_aux_out);
        sub(reg_ptr_aux_inp_h, inp_d_offset);
        add(reg_ptr_aux_out, out_d_offset);

        dec(reg_kd);
        jnz(kd_label, T_NEAR);
        L(no_kd_label);
    }
}
2958 | |
// Top-level code generation for the bwd/d copy kernel: loads the runtime
// arguments, prepares the oc-tail k-mask (if any), and invokes kd_loop()
// once per oc_block_int chunk - unmasked for full chunks, masked for the
// tail chunk.
void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {

    // per-oc-chunk steps in the source and the padded destination buffer
    const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
    const int out_c_step = jcp.kd * jcp.ohp * jcp.owp * inp_c_step;
    const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
    const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;

    preamble();

    // pointer to 1st needed element in src buffer
    mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
    // pointer to 1st needed element in dst buffer
    mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);

    // number of rows of src buffer to copy
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);

    // number of columns of src buffer to copy
    mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
    // number of zero-padded columns after src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(r_overflow)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    // build the k-mask that covers the oc tail elements
    if (oc_block_int_tail > 0) {
        uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    if (nb_oc_int_no_tail == 0) {
        kd_loop(true); // masked
    } else if (nb_oc_int_no_tail == 1) {
        kd_loop(false); // unmasked!
        if (oc_block_int_tail > 0) {
            add(reg_ptr_inp, inp_c_step);
            add(reg_ptr_out, out_c_step);
            kd_loop(true); // masked
        }
    } else if (nb_oc_int_no_tail > 1) {
        // loop over all full (unmasked) oc chunks, then handle the tail
        mov(reg_cnt_ocb, nb_oc_int_no_tail);
        Label label_ocb_loop;
        L(label_ocb_loop);
        {
            kd_loop(false); // unmasked!
            add(reg_ptr_inp, inp_c_step);
            add(reg_ptr_out, out_c_step);
            dec(reg_cnt_ocb);
            jnz(label_ocb_loop);
        }
        if (oc_block_int_tail > 0) kd_loop(true); // masked
    }

    postamble();
}
3016 | |
3017 | // Tile register decomposition |
3018 | // { C_BASE = 0, I_BASE = 4, W_BASE = 6, } |
3019 | int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const { |
3020 | const int C_BASE = 0; |
3021 | const int C_LAST = 4; |
3022 | assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles); |
3023 | MAYBE_UNUSED(C_LAST); |
3024 | const int tile = C_BASE + h * jcp.nb_ih_blocking + i; |
3025 | assert(C_BASE <= tile && tile < C_LAST); |
3026 | return tile; |
3027 | } |
3028 | int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const { |
3029 | const int I_BASE = 4; |
3030 | const int I_LAST = 6; |
3031 | assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles); |
3032 | MAYBE_UNUSED(I_LAST); |
3033 | const int tile = I_BASE + h; |
3034 | assert(I_BASE <= tile && tile < I_LAST); |
3035 | return tile; |
3036 | } |
3037 | int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const { |
3038 | const int W_BASE = 6; |
3039 | const int W_LAST = 8; |
3040 | assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles); |
3041 | MAYBE_UNUSED(W_LAST); |
3042 | const int tile = W_BASE + i; |
3043 | assert(W_BASE <= tile && tile < W_LAST); |
3044 | return tile; |
3045 | } |
3046 | |
3047 | // Strides, shifts and offsets |
3048 | // - inp is a padded buffer ~ [nb_oc_int][kd][ohp][owp]{32c,64c} |
3049 | // - weights is user buffer ~ OIdhw16o16i{2o,4o} |
3050 | // - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c] |
3051 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const { |
3052 | return (size_t)jcp.typesize_in * jcp.kd * jcp.ohp * jcp.owp |
3053 | * jcp.oc_block_int; |
3054 | } |
3055 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const { |
3056 | return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int; |
3057 | } |
3058 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_d_step() const { |
3059 | return static_cast<size_t>(jcp.typesize_in) * jcp.ohp * jcp.owp |
3060 | * jcp.oc_block_int; |
3061 | } |
3062 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset( |
3063 | int ihb, int kh, int kw) const { |
3064 | // calculate offset by src height dimension |
3065 | size_t sp_offset = (size_t)ihb * jcp.owp; |
3066 | // add offset by kernel height dimension |
3067 | sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp; |
3068 | // add offset by kernel width dimension |
3069 | sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1); |
3070 | return jcp.typesize_in * sp_offset * jcp.oc_block_int; |
3071 | } |
3072 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const { |
3073 | return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block; |
3074 | } |
3075 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const { |
3076 | const bool is_deconv = jcp.prop_kind != prop_kind::backward_data; |
3077 | return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kd |
3078 | * jcp.kh * jcp.kw * jcp.oc_block_int * jcp.ic_block; |
3079 | } |
3080 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset( |
3081 | int icb, int kh, int kw) const { |
3082 | const bool is_deconv = jcp.prop_kind != prop_kind::backward_data; |
3083 | const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block; |
3084 | const size_t wei_kh_stride = jcp.kw * wei_kw_stride; |
3085 | const size_t wei_kd_stride = jcp.kh * wei_kh_stride; |
3086 | const size_t wei_icb_stride |
3087 | = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kd * wei_kd_stride; |
3088 | return jcp.typesize_in |
3089 | * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride); |
3090 | } |
3091 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_d_step() const { |
3092 | const size_t wei_kd_stride |
3093 | = jcp.kh * jcp.kw * jcp.oc_block_int * jcp.ic_block; |
3094 | // step through 'kd' weight elements by `stride_d` |
3095 | return static_cast<size_t>(jcp.typesize_in) * jcp.stride_d * wei_kd_stride; |
3096 | } |
3097 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset( |
3098 | int ihb, int icb) const { |
3099 | size_t el_offset = jcp.is_nspc |
3100 | ? (size_t)icb * jcp.ic_block |
3101 | + (size_t)ihb * jcp.iw * jcp.ngroups |
3102 | * jcp.ic_without_padding |
3103 | : (size_t)icb * jcp.id * jcp.ih * jcp.iw * jcp.ic_block |
3104 | + (size_t)ihb * jcp.iw * jcp.ic_block; |
3105 | return (size_t)jcp.typesize_out * el_offset; |
3106 | } |
3107 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset( |
3108 | int ihb, int icb, int j) const { |
3109 | size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups |
3110 | * jcp.ic_without_padding |
3111 | : (size_t)jcp.typesize_out * j * jcp.ic_block; |
3112 | return get_out_icb_offset(ihb, icb) + offset_w; |
3113 | } |
3114 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const { |
3115 | return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups |
3116 | * jcp.ic_without_padding |
3117 | : (size_t)jcp.typesize_out * width * jcp.ic_block; |
3118 | } |
3119 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset( |
3120 | int ihb, int icb) const { |
3121 | size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block |
3122 | + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width |
3123 | * jcp.ic_block; |
3124 | return jcp.typesize_acc * el_offset; |
3125 | } |
3126 | size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset( |
3127 | int ihb, int icb, int j) const { |
3128 | return get_wsp_icb_offset(ihb, icb) |
3129 | + (size_t)jcp.typesize_acc * j * jcp.ic_block; |
3130 | } |
3131 | |
3132 | // Code generation |
3133 | void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() { |
3134 | for (int h = 0; h < jcp.nb_ih_blocking; h++) |
3135 | for (int i = 0; i < jcp.nb_ic_blocking; i++) |
3136 | tilezero(Tmm(get_out_tensor(h, i))); |
3137 | } |
3138 | |
3139 | void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters( |
3140 | bool start_with_last_tile_block) { |
3141 | prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0 |
3142 | ? jcp.tile_tail |
3143 | : jcp.tile_width; |
3144 | |
3145 | row_count_ = 0; |
3146 | is_store_done_ = false; |
3147 | is_buffer_empty_ = true; |
3148 | } |
3149 | |
3150 | bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) { |
3151 | using namespace primitive_kind; |
3152 | const auto &p = attr_.post_ops_; |
3153 | |
3154 | if (position == 0) { |
3155 | /* eltwise before sum */ |
3156 | return p.contain(eltwise, 0); |
3157 | } else if (position == 1) { |
3158 | /* eltwise after sum */ |
3159 | return p.contain(sum, 0) && p.contain(eltwise, 1); |
3160 | } |
3161 | |
3162 | return false; |
3163 | } |
3164 | |
3165 | Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask( |
3166 | const Ymm &ymm_in, bool mask_flag, bool store) { |
3167 | return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z) |
3168 | : ymm_in; |
3169 | } |
3170 | |
3171 | Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask( |
3172 | const Zmm &zmm_in, bool mask_flag, bool store) { |
3173 | return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
3174 | : zmm_in; |
3175 | } |
3176 | |
// Load `op` into `zmm_in`, widening integer inputs to s32 lanes and then
// converting non-f32 inputs to f32 in place. The tail opmask (zeroing form)
// is applied when `mask_flag` is set.
void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm &zmm_in, const Operand &op, bool mask_flag) {
    const Zmm zmm = zmm_mask(zmm_in, mask_flag);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: vmovups(zmm, op); break;
        case data_type::s8: vpmovsxbd(zmm, op); break; // sign-extend s8->s32
        case data_type::u8: vpmovzxbd(zmm, op); break; // zero-extend u8->s32
        default: assert(!"unsupported data type" );
    }
    // s8/u8/s32 lanes still hold integers at this point; convert to f32
    if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
}
3189 | |
// Store one f32 accumulator vector to diff_src for the bf16 kernel,
// applying (in order): sum post-op, bias, eltwise post-op, and finally a
// down-conversion to bf16 when the destination datatype requires it.
void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
        const Zmm &zmm_out, int icb, int h, int w) {
    // mask only the last ic block, and only when the nspc layout has an
    // ic tail (ic_without_padding != padded ic)
    const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
            && icb == (jcp.nb_ic_blocking - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;

    const int sum_idx = p.find(primitive_kind::sum);
    if (sum_idx != -1) {
        if (jcp.dsrc_dt == data_type::bf16) {
            // widen bf16 -> f32 by zero-extending to 32 bits and shifting
            // the payload into the upper half (bf16 is the high 16 bits of
            // an f32)
            vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vpslld(zmm_prev_dst, zmm_prev_dst, 16);
            vaddps(zmm_out, zmm_prev_dst);
        } else {
            // f32 destination: load and accumulate directly
            vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vaddps(zmm_out, zmm_prev_dst);
        }
    }
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        if (jcp.bia_dt == data_type::bf16) {
            // same bf16 -> f32 widening trick as above
            vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
            vpslld(zmm_bias, zmm_bias, 16);
            vaddps(zmm_out, zmm_bias);
        } else
            vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
    }

    const int eltwise_ind = p.find(primitive_kind::eltwise);
    if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());

    if (jcp.dsrc_dt == data_type::bf16) {
        // down-convert f32 -> bf16 and store 16-bit lanes
        Ymm ymm_out = Ymm(zmm_out.getIdx());
        vcvtneps2bf16(ymm_out, zmm_out);
        vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
    } else {
        vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
    }
}
3232 | |
// Store one s32 accumulator vector to diff_src for the int8 kernel:
// convert to f32, apply scales, bias, the post-op chain
// (eltwise / sum-with-scale-and-zero-point / eltwise), dst scale,
// saturation, and a conversion back to the destination integer type.
void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
        const Zmm &zmm_out, int icb, int h, int w) {
    const int nb_ic_block = jcp.nb_ic_blocking;
    const int ic_block = jcp.ic_block;
    // mask only the last ic block when there is an ic tail
    const bool mask_flag = true && jcp.ic_without_padding != jcp.ic
            && icb == (nb_ic_block - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    // materialize pointers to sum scale / zero-point only when they are
    // actually needed by the broadcast loads below
    if (p_sum_scale) {
        if (*p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        if (*p_sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
    }

    // scale_offset is 0 for a common (per-tensor) scale
    int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
    }

    /* add bias to zmm_accum */
    vcvtdq2ps(zmm_out, zmm_out);
    const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
    vmulps(zmm_out_msk, zmm_out,
            EVEX_compress_addr(reg_ptr_scales, scale_offset));
    if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);

    /* Do post-ops */
    if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
    if (p_sum_scale) { // post_op: sum
        cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
        if (*p_sum_zp != 0) {
            // subtract the sum zero-point (broadcast) before scaling
            vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
            vsubps(zmm_prev_dst, zmm_sum_zp);
        }
        if (*p_sum_scale == 1.f)
            vaddps(zmm_out, zmm_prev_dst);
        else
            vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
    }
    if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());

    if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
        saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
        vcvtps2dq(zmm_out, zmm_out);
    }

    const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);

    switch (jcp.dsrc_dt) {
        case data_type::f32:
        case data_type::s32: vmovups(addr, zmm_out_store); break;
        case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
        case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
        default: assert(!"unknown dst_dt" );
    }
}
3308 | |
3309 | void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector( |
3310 | const Zmm &zmm_out, int icb, int h, int w) { |
3311 | /* |
3312 | Output: |
3313 | jcp.is_nspc !jcp.is_nspc |
3314 | ------------------------ --------------------- |
3315 | INT8: [N][D][H][W][NBIC][16IC] |
3316 | BF16: [N][D][H][W][NBIC][16IC] or [N][NBIC][D][H][W][16IC] |
3317 | */ |
3318 | if (jcp.ddst_dt == data_type::bf16) { |
3319 | store_output_vector_bf16(zmm_out, icb, h, w); |
3320 | } else { |
3321 | store_output_vector_int8(zmm_out, icb, h, w); |
3322 | } |
3323 | } |
3324 | |
// Spill all accumulator tiles to the s32 workspace and, when `do_store` is
// set, convert/store every row to the final diff_src buffer. The ih tail is
// handled at runtime via `reg_last_h` when jcp.ih is not a multiple of the
// ih blocking.
void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
        int width, bool do_store) {
    auto store_output_block = [=](int width, bool do_store,
                                      bool is_last_ih_blks) {
        // Calculate the number of ih blocks; it may differ on last call
        const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
                                              : jcp.nb_ih_blocking;
        for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
            for (int ihb = 0; ihb < n_ih_blks; ihb++) {
                /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
                tilestored(ptr[reg_wsp_ptr + reg_wei_stride
                                   + get_wsp_icb_offset(ihb, icb)],
                        Tmm(get_out_tensor(ihb, icb)));
                is_buffer_empty_ = false;
                is_store_done_ = false;
                // eagerly store every row when interleaving is disabled
                for (int tw = 0; tw < width && do_store; tw++) {
                    Zmm zmm_out = Zmm(tw);
                    vmovups(zmm_out,
                            ptr[reg_wsp_ptr
                                    + get_wsp_row_offset(ihb, icb, tw)]);
                    store_output_vector(zmm_out, icb, ihb, tw);
                }
            }
        }
    };

    // adjustment in case interleave store is turned off
    do_store = do_store || jcp.per_one_pstore == 0;
    if (jcp.ih % jcp.nb_ih_blocking == 0) {
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
    } else {
        // runtime branch: the last row of blocks stores fewer ih blocks
        Label label_full_store, label_done;
        cmp(reg_last_h, 0);
        jne(label_full_store, T_NEAR);
        store_output_block(width, do_store, /* is_last_ih_blks = */ true);
        jmp(label_done, T_NEAR);
        L(label_full_store);
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
        L(label_done);
    }
    if (do_store) add(reg_out_ptr, get_out_shift(width));
}
3367 | |
// Interleave up to `jcp.per_one_pstore` row stores from the workspace into
// the compute loop, so stores overlap with tile computations. Bookkeeping
// (row_count_, prv_width_, is_store_done_) is compile-time state tracked
// across generated iterations.
void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
    for (int c = 0;
            c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
            c++) {
        // row_count = ihb * ICB * TW + icb * TW + tw
        int tw = row_count_ % prv_width_;
        int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
        int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;

        Zmm zmm_out = Zmm(tw);
        vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
        store_output_vector(zmm_out, icb, ihb, tw);
        row_count_++;

        // once every row of the previous tile block has been flushed,
        // advance the output pointer and reset the counters
        if (row_count_
                == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
            add(reg_out_ptr, get_out_shift(prv_width_));

            // remember the pre-reset state so skipped_interleave_store()
            // can replay an interrupted flush
            is_store_done_save_ = is_store_done_;
            prv_width_save_ = prv_width_;

            row_count_ = 0;
            is_store_done_ = true;
            prv_width_ = width;
        }
    }
}
3395 | |
// Flush the previous tile block's workspace rows when the regular
// interleaved stores were skipped (e.g. 'kd_padding' == 0 skipped the whole
// compute_ocb_loop). Uses the saved width/state from interleave_store().
void jit_avx512_core_amx_bwd_data_kernel_t::skipped_interleave_store() {

    // nothing pending: either already flushed or nothing was accumulated
    if (is_store_done_save_ || is_buffer_empty_) return;

    const int store_count
            = prv_width_save_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking;
    for (int row_count = 0; row_count < store_count; row_count++) {
        // row_count = ihb * ICB * TW + icb * TW + tw
        int tw = row_count % prv_width_save_;
        int icb = (row_count / prv_width_save_) % jcp.nb_ic_blocking;
        int ihb = (row_count / prv_width_save_) / jcp.nb_ic_blocking;

        Zmm zmm_out = Zmm(tw);
        vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
        store_output_vector(zmm_out, icb, ihb, tw);
    }
    is_store_done_save_ = true;
    add(reg_out_ptr, get_out_shift(prv_width_save_));
}
3415 | |
// Core AMX compute loop: for every oc-int block and every (kh, kw) weight
// element, load input and weight tiles and issue the matching tile
// dot-product, optionally interleaving workspace stores. Restores the
// input/weight pointers to their starting values before returning.
void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
        int width, bool do_interleave_store) {

    // pick the tile dot-product matching the compute datatype
    auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
        switch (jcp.ddst_dt) {
            using namespace data_type;
            case bf16: tdpbf16ps(x1, x2, x3); break;
            case s8: tdpbssd(x1, x2, x3); break;
            case u8: tdpbusd(x1, x2, x3); break;
            default: assert(!"unsupported data type" );
        }
    };

    for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
        // reverse order through spatial components of weights so that
        // input buffer is accessed in a monotonically increasing fashion
        for (int kh = jcp.kh - 1; kh >= 0; kh--) {
            for (int kw = jcp.kw - 1; kw >= 0; kw--) {
                for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                    tileloadd(Tmm(get_inp_tensor(ihb)),
                            ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
                                    + reg_inp_stride]);
                }
                for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
                    tileloadd(Tmm(get_wei_tensor(icb)),
                            ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
                                    + reg_wei_stride]);
                    for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                        tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
                                Tmm(get_inp_tensor(ihb)),
                                Tmm(get_wei_tensor(icb)));
                        // overlap stores with tile computations
                        if (do_interleave_store) interleave_store(width);
                    }
                }
            }
        }
        add(reg_inp_ptr, get_inp_ocb_step());
        add(reg_wei_ptr, get_wei_ocb_step());
    }
    // rewind pointers so the caller sees them unchanged
    sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
    sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);
}
3458 | |
// Compute one spatial tile block: zero the accumulators, run the ocb loop
// (with a runtime 'kd' loop for 3D cases), then store the results.
// `kd_padding` is a runtime value, so the kd loop is generated as a single
// iteration plus a backward jump rather than being unrolled.
void jit_avx512_core_amx_bwd_data_kernel_t::compute_kd_loop(
        int width, bool do_store, bool handle_skipped_stores) {

    Label skip_compute_kd_label, kd_loop_label, end_kd_compute_label;

    prepare_output();

    if (jcp.ndims == 5) {
        // preserve pointers mutated by the per-kd stepping below
        push(reg_inp_ptr);
        push(reg_wei_ptr);

        mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_kd, 0);
        jle(skip_compute_kd_label, T_NEAR);
    }

    // first kd iteration interleaves stores with compute
    compute_ocb_loop(width, true);

    if (jcp.ndims == 5) {
        L(kd_loop_label);

        // for bwd_d, filter elements are stepped through in reverse, e.g.:
        // diff_dst:  [0, 1, 2, 3, 4]   diff_src:
        // wei:       [2, 1, 0]            [0]
        //               [2, 1, 0]         [1]
        //                  [2, 1, 0]      [2]
        // which results in 'data_copy_kernel_t' copying the 'kd' dimension
        // corresponding to 'diff_dst' elements in reverse. The layout for the
        // 'kd' elements of 'inp_buff' to compute 'diff_src[0]' are:
        //
        // [kd]:
        //    inp_buff: [nb_oc_int]{2,1,0}[ohp][owp]{32c,64c}
        //
        // then,
        //    inp_buff: [nb_oc_int]{3,2,1}[ohp][owp]{32c,64c}
        //
        // and so on...
        //
        // hence, step through 'reg_inp_ptr' in ascending order:
        add(reg_inp_ptr, get_inp_d_step());
        add(reg_wei_ptr, get_wei_d_step());

        dec(reg_kd);
        jz(end_kd_compute_label, T_NEAR);

        // because 'kd' is dynamic, it may skip interleaved stores, so do not
        // unroll for more than one kd iteration.
        compute_ocb_loop(width, false);
        jmp(kd_loop_label, T_NEAR);

        L(skip_compute_kd_label);

        // 'kd_padding' may be '0' due to gaps in the filter / diff_dst
        // computation due to stride or dilation, so 'compute_ocb_loop()' may be
        // skipped. This results in skipping interleaved_stores() as well. So
        // call a special case of 'interleave_store()' for such cases.
        if (handle_skipped_stores) skipped_interleave_store();

        L(end_kd_compute_label);

        pop(reg_wei_ptr);
        pop(reg_inp_ptr);
    }

    store_output(width, do_store);

    add(reg_inp_ptr, get_inp_shift());
}
3527 | |
// Generate the iw (tile-width) loop. When jcp.nb_iw > 1 the current iw
// block index arrives at runtime via the 'iwb' kernel argument, so a
// runtime branch selects between the last-block body (which may carry a
// tail tile) and the regular body.
void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
    auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
        // check if there are '0' gaps in input stores due to dilation or stride
        bool handle_skipped_stores = gaps_in_store() && num_tile_blocks > 1;

        int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
                                                          : jcp.tile_width;
        init_runtime_counters(last_iwb && num_tile_blocks == 1);
        // all but the last tile block defer the final store
        for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
            compute_kd_loop(jcp.tile_width, false, handle_skipped_stores);
        compute_kd_loop(gen_tile_tail, true, handle_skipped_stores);
    };

    if (jcp.nb_iw == 1) {
        compute_iw_loop_body(true, jcp.iw_blocks);
    } else {
        Label label_done;
        int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
        int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
        if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
            last_iwb_tile_blocks = iw_blocks_per_call;
        if (last_iwb_tile_blocks > 0) {
            // runtime check: is this the last iw block?
            Label label_not_last_iwb;
            mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
            cmp(reg_tmp, jcp.nb_iw - 1);
            jne(label_not_last_iwb, T_NEAR);

            compute_iw_loop_body(true, last_iwb_tile_blocks);

            jmp(label_done, T_NEAR);

            L(label_not_last_iwb);
        }
        compute_iw_loop_body(false, iw_blocks_per_call);

        L(label_done);
    }
}
3566 | |
// Top-level JIT entry point: load kernel arguments into registers, set up
// the ic-tail opmask when needed, and emit the iw loop body.
void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
    preamble();

    mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
    mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
    mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
    mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]); // s32 workspace

    if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);

    if (jcp.dst_scale) {
        // broadcast the (common) dst scale once for the whole kernel
        mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]);
        vmovups(zmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scales, 0));
    }
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);

    const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
    const int wei_stride = jcp.ic_block * jcp.typesize_acc;
    mov(reg_inp_stride, inp_stride);
    mov(reg_wei_stride, wei_stride);

    if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
        // Use mask 0xF by default for all output data and post-ops
        // loads / stores with block index
        // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
        // TODO: use masked loads / stores for the last icc only
        int current_block_size = jcp.ic_block;
        int mask = (1 << current_block_size) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
        Xbyak::Label mask_is_set;
        mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
        cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
        jne(mask_is_set, T_NEAR);
        // Reset the mask
        current_block_size = jcp.ic_without_padding % jcp.ic_block;
        mask = (1 << current_block_size) - 1;
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);

        L(mask_is_set);
    }
    compute_iw_loop();

    postamble();

    // eltwise injector emits its lookup table after the kernel body
    if (jcp.with_eltwise) eltwise_injector_->prepare_table();
}
3618 | |
3619 | bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok( |
3620 | const jit_conv_conf_t &jcp, primitive_attr_t &attr) { |
3621 | using namespace primitive_kind; |
3622 | const auto &p = attr.post_ops_; |
3623 | const bool is_bf16 = jcp.ddst_dt == data_type::bf16; |
3624 | |
3625 | auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; |
3626 | |
3627 | auto is_sum = [&](int idx) { |
3628 | if (is_bf16) |
3629 | return p.entry_[idx].is_sum(); |
3630 | else |
3631 | return p.contain(sum, idx); |
3632 | }; |
3633 | |
3634 | switch (p.len()) { |
3635 | case 0: return true; |
3636 | case 1: return is_eltwise(0) || is_sum(0); |
3637 | case 2: |
3638 | return (is_sum(0) && is_eltwise(1)) |
3639 | || (!is_bf16 && is_sum(1) && is_eltwise(0)); |
3640 | default: return false; |
3641 | } |
3642 | |
3643 | return false; |
3644 | } |
3645 | |
// Fill the 64-byte AMX palette configuration buffer: rows/column-bytes for
// the weight, input, and accumulator tiles, plus the palette id.
void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
    // elements packed per dword: 2 for bf16, 4 for int8
    const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
    // Input tile dimensions
    const int a_col = jcp.oc_block_int;
    const int a_row = jcp.tile_width;
    // Weights tile dimensions
    const int b_col = jcp.ic_block * vnni_width;
    const int b_row = a_col / vnni_width;
    // Accumulator tile dimensions
    const int c_col = jcp.ic_block;
    const int c_row = a_row;

    // the config buffer must be fully zeroed before programming tiles
    for (size_t i = 0; i < 64; i++)
        tcfg_buff[i] = 0;

    // Weights (W_BASE) Tensor Tiles
    for (int i = 0; i < jcp.nb_ic_blocking; i++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
                b_row, b_col * jcp.typesize_in);

    // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
    for (int h = 0; h < jcp.nb_ih_blocking; h++) {
        tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
                a_row, a_col * jcp.typesize_in);
        for (int i = 0; i < jcp.nb_ic_blocking; i++)
            tc_configure_tile((palette_config_t *)tcfg_buff,
                    get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
    }

    ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette();
}
3677 | |
// Initialize the jit_conv_conf_t descriptor for the AMX bwd_data /
// deconvolution kernel: validate datatypes, memory formats, padding and
// attributes; choose blocking parameters and scratch-buffer sizes; return
// status::unimplemented for any unsupported configuration.
status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &diff_src_md,
        memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
        memory_desc_t *bias_md, primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper bias_d(bias_md);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    int ndims = diff_src_d.ndims();
    bool is_1d = ndims == 3;
    bool is_3d = ndims == 5;

    using namespace data_type;
    // this kernel serves both bwd_data convolution and fwd deconvolution
    const bool is_deconv = cd.prop_kind != prop_kind::backward_data;

    const bool is_bf16 = everyone_is(true, diff_dst_d.data_type() == bf16,
            weights_d.data_type() == bf16,
            one_of(diff_src_d.data_type(), bf16, f32));
    const bool is_bf16_convolution = is_bf16 && !is_deconv;
    const bool is_bf16_deconvolution = is_bf16 && is_deconv;
    const bool is_int8_deconvolution = is_deconv
            && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
                    weights_d.data_type() == s8,
                    one_of(diff_src_d.data_type(), f32, s32, s8, u8));

    bool supported
            = mayiuse(avx512_core_amx) && (is_bf16 || is_int8_deconvolution);
    if (!supported) return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.isa = avx512_core_amx;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;

    // problem sizes (note: 'i' = diff_src side, 'o' = diff_dst side)
    jcp.mb = diff_src_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.id = is_3d ? diff_src_d.dims()[2] : 1;
    jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.od = is_3d ? diff_dst_d.dims()[2] : 1;
    jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
    jcp.ow = diff_dst_d.dims()[ndims - 1];
    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
    jcp.stride_w = cd.strides[ndims - 3];

    // No bias for bf16 case to simplify integration with ref_deconvolution
    jcp.with_bias = bias_md && !is_bf16_convolution
            && cd.bias_desc.format_kind != format_kind::undef;

    jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
    jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
    jcp.dilate_w = cd.dilates[ndims - 3];

    // combined depth dilation and stride is not supported
    if (jcp.dilate_d != 0 && jcp.stride_d != 1) return status::unimplemented;

    // "generalized" (dilation-expanded) kernel extents
    const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
    // padding must not exceed the effective kernel extent
    if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
            || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
            || jcp.back_pad >= gen_kd)
        return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    // deconvolution swaps the roles of src and dst descriptors
    if (is_deconv) {
        jcp.ddst_dt = cd.src_desc.data_type;
        jcp.dsrc_dt = cd.dst_desc.data_type;
    } else {
        jcp.ddst_dt = cd.diff_dst_desc.data_type;
        jcp.dsrc_dt = cd.diff_src_desc.data_type;
    }
    jcp.wei_dt = cd.weights_desc.data_type;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    format_tag_t dat_tag_ncsp = pick(ndims - 3, format_tag::nCw16c,
            format_tag::nChw16c, format_tag::nCdhw16c);
    format_tag_t dat_tag_nspc = pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    // To toggle the default data layout for BF16 between nChw16c and nhwc,
    // swap the following two variable definitions. Current choice: nhwc.
    format_tag_t dat_tag_opt = dat_tag_nspc;
    format_tag_t dat_tag_alt = is_bf16 ? dat_tag_ncsp : dat_tag_nspc;

    if (diff_src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);

    if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
        return status::unimplemented;

    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
    assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));

    // TODO: remove all support for nChw16c from this implementation
    if (!jcp.is_nspc) return status::unimplemented;

    // diff_dst must use the same layout as diff_src
    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);

    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
        CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));

    jcp.nthr = nthreads;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    if (jcp.ngroups == 1) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }
    bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
    if (!args_ok) return status::unimplemented;

    const int vnni_width = is_bf16 ? 2 : 4;
    jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8

    if (attr.set_default_formats(&diff_src_md) != status::success)
        return status::unimplemented;
    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    // pick (or verify) the VNNI-interleaved weights layout matching the
    // propagation kind and datatype
    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (is_bf16_convolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
                    gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o, OIdhw16o16i2o,
                    gOIdhw16o16i2o);
        else if (is_bf16_deconvolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
                    gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i, OIdhw16i16o2i,
                    gOIdhw16i16o2i);
        else if (is_int8_deconvolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
                    gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i, OIdhw16i16o4i,
                    gOIdhw16i16o4i);
        else {
            assert(!"unsupported combination" );
            return false;
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }
        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;

    jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
    jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_acc = sizeof(int32_t);

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);

    // this implementation assumes the standard 8-tile / 16-row palette
    const int target_palette = amx::get_target_palette();
    jcp.max_tiles = amx::get_max_tiles(target_palette);
    jcp.full_tile_width = amx::get_max_rows(target_palette);
    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
    jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);

    // Prefer to use a single tile width when possible
    // (eg iw28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
    if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
    jcp.tile_tail = jcp.iw % jcp.tile_width;

    jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
    jcp.nb_ih_blocking
            = everyone_is(true, jcp.ih > 1,
                      // requirement for interleave stores
                      IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
            ? 2
            : 1;

    // TODO: tune ih blocking
    const int ih_blk_size_tmp = 10;
    const int ih_step = jcp.nb_ih_blocking;
    jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
    // ohp includes all elements that are really used in calculation,
    // including zero-padded "dilate-by-strides" and top and bottom overflow
    jcp.ohp = jcp.ih_blk_size + gen_kh - 1;

    // TODO: tune iw blocking
    const int iw_blocks_per_call = 2;
    jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
    jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
    // owp includes all elements that are really used in calculation,
    // including zero-padded "dilate-by-strides" and left and right overflow
    jcp.owp = jcp.iw_block + gen_kw - 1;

    // Number of ops per tile store
    int ops_tile_store = jcp.tile_width;
    // Number of ops per accumulation tile
    int avaliable_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
    // Number of vectors to store per tile operation
    // NOTE: set to zero to turn off interleave store (mostly for debugging)
    jcp.per_one_pstore = div_up(ops_tile_store, avaliable_ops);

    // scratchpad sizing (elements, not bytes)
    jcp.inp_buffer_size = static_cast<size_t>(jcp.nb_oc_int) * jcp.kd * jcp.ohp
            * jcp.owp * jcp.oc_block_int;
    jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
            * jcp.full_tile_width * jcp.ic_block;

    const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
    const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
    const int wei_mask_per_ic = 1 << (int)with_groups;
    jcp.is_ic_scale = wei_scales.mask_ == wei_mask_per_ic;
    jcp.dst_scale = !dst_scales.has_default_values();

    // only common src & dst scales are supported
    // only common and per-oc-channel weight scales are supported
    const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_ic)
            && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
    if (!scales_ok) return status::unimplemented;

    return status::success;
}
3944 | |
3945 | void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad( |
3946 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
3947 | const primitive_attr_t &attr) { |
3948 | |
3949 | size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size; |
3950 | scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in); |
3951 | size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size; |
3952 | scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc); |
3953 | if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) { |
3954 | assert(jcp.ngroups == 1); |
3955 | scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia); |
3956 | } |
3957 | scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline |
3958 | |
3959 | book_precomputed_scales( |
3960 | scratchpad, attr.scales_, jcp.ngroups * jcp.ic_without_padding); |
3961 | } |
3962 | |
// Upper bound on the output-width unrolling factor for bwd-weights kernels.
const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;
3964 | |
3965 | // Tile register decomposition |
3966 | // { C_BASE = 0, A_BASE = 4, B_BASE = 6, } |
3967 | int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor( |
3968 | int ocb, int icb) const { |
3969 | const int C_BASE = 0; |
3970 | const int C_LAST = 4; |
3971 | assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles); |
3972 | MAYBE_UNUSED(C_LAST); |
3973 | const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb; |
3974 | assert(C_BASE <= tile && tile < C_LAST); |
3975 | return tile; |
3976 | } |
3977 | int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const { |
3978 | const int A_BASE = 4; |
3979 | const int A_LAST = 6; |
3980 | assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles); |
3981 | MAYBE_UNUSED(A_LAST); |
3982 | const int tile = A_BASE + icb; |
3983 | assert(A_BASE <= tile && tile < A_LAST); |
3984 | return tile; |
3985 | } |
3986 | int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const { |
3987 | const int B_BASE = 6; |
3988 | const int B_LAST = 8; |
3989 | assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles); |
3990 | MAYBE_UNUSED(B_LAST); |
3991 | const int tile = B_BASE + ocb; |
3992 | assert(B_BASE <= tile && tile < B_LAST); |
3993 | return tile; |
3994 | } |
3995 | |
3996 | void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(char *tcfg_buff) { |
3997 | // Input tile dimensions |
3998 | const int a_col = jcp.ur_w; |
3999 | const int a_row = jcp.ic_block; |
4000 | // Weights tile dimensions |
4001 | const int b_col = jcp.oc_block * 2; |
4002 | const int b_row = a_col / 2; |
4003 | // Accumulator tile dimensions |
4004 | const int c_col = jcp.oc_block; |
4005 | const int c_row = a_row; |
4006 | |
4007 | for (size_t i = 0; i < 64; i++) |
4008 | tcfg_buff[i] = 0; |
4009 | |
4010 | for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) |
4011 | tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb), |
4012 | a_row, a_col * jcp.typesize_in); |
4013 | |
4014 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) |
4015 | tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb), |
4016 | b_row, b_col * jcp.typesize_in); |
4017 | |
4018 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) |
4019 | for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) |
4020 | tc_configure_tile((palette_config_t *)tcfg_buff, |
4021 | get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out); |
4022 | |
4023 | ((palette_config_t *)tcfg_buff)->palette_id = amx::get_target_palette(); |
4024 | } |
4025 | |
// Rewind reg_src and reg_kernel after a kd loop: undoes reg_kd_count
// iterations worth of depth-wise pointer advances.
void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
    Label kd_comeback_label;
    mov(kj, reg_kd_count);
    L(kd_comeback_label);
    {
        // Step back one filter-depth element in src and one kh*kw filter
        // plane in the weights per iteration.
        sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(kj);
        jnz(kd_comeback_label, T_NEAR);
    }
}
4037 | |
// Rewind reg_src and reg_kernel after a kh loop: undoes reg_kh iterations
// worth of height-wise pointer advances.
void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
    Label kh_comeback_label;
    mov(kj, reg_kh);
    L(kh_comeback_label);
    {
        // Step back one filter-height element in src and one kw row in the
        // weights per iteration.
        sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kw));
        dec(kj);
        jnz(kh_comeback_label, T_NEAR);
    }
}
4049 | |
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
        int nb_ic_blocking, int nb_oc_blocking) {
    // General code layout:
    //
    // Blocking over OH -- top level
    // (Reduces L2 pressure; not very useful right now)
    //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
    //    Loop over OH block -- emit_h_loop()
    //      Loop over OW blocks -- emit_fma_block()
    //      (Supports both fully unrolled and partially unrolled
    //      versions to reduce code size)
    //          Loop over OW block -- emit_fma_step()

    // Bytes touched per output row (one src row + one diff_dst row); used
    // to pick an OH block size that keeps the working set cache-resident.
    auto src_row_size = get_src_offset(0, 0, 1);
    auto ddst_row_size = get_ddst_offset(0, 1);
    auto row_size = src_row_size + ddst_row_size;

    int h_block_size = jcp.oh;
    int h_last_block_size = h_block_size;
    int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
    auto working_set_size = row_size * h_block_size;

    if (working_set_size > full_spat_max_working_set_size) {
        assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);

        // Shrink the block by dividing out its smallest factor (halving if
        // prime) until the working set fits or the block hits the minimum
        // imposed by padding.
        while (working_set_size > full_spat_opt_working_set_size
                && h_block_size >= min_h_block_size) {
            for (int i = 2; i <= h_block_size; i++)
                if (i == h_block_size)
                    h_block_size = h_block_size / 2;
                else if (h_block_size % i == 0) {
                    h_block_size = h_block_size / i;
                    break;
                }
            working_set_size = row_size * h_block_size;
        }
        h_block_size = nstl::max(min_h_block_size, h_block_size);
        h_last_block_size = jcp.oh % h_block_size;
        // The last block must cover at least the bottom padding.
        if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
    }

    // Local register assignment for this emitter (opmask k1 is used as a
    // scratch counter for the outer h-block loop).
    Opmask reg_h_block = k1;
    Reg64 reg_kh = rax;
    Reg64 reg_kw = rbx;
    Reg64 reg_tmp = abi_not_param1;
    Reg32 reg_tmp_w = reg_tmp.cvt32();
    Reg64 reg_ohs = rdx;
    Reg64 reg_ihs = rsi;
    Reg64 reg_h = r8;
    Reg64 reg_j = r10;

    Reg64 reg_src = r13;
    Reg64 reg_ddst = r14;
    Reg64 reg_ker = r15;

    Reg64 reg_dense_stride = abi_param1;
    Reg64 reg_a_stride = reg_tmp;

    // Emit one row's worth of tile loads and bf16 dot-product accumulation
    // over all ur_w blocks.
    auto emit_block = [&]() {
        mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
        for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
            dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
            dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);

            for (int icb = 0; icb < nb_ic_blocking; icb++) {
                dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
                tileloadd(Tmm(get_src_tensor(icb)),
                        ptr[reg_src + reg_a_stride + icb_offset
                                + ur_w_src_offset]);
            }
            for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                tileloadd(Tmm(get_ddst_tensor(ocb)),
                        ptr[reg_ddst + reg_dense_stride
                                + jcp.typesize_in * ocb
                                        * jcp.tr_diff_dst_buf_size
                                + ur_w_ddst_offset]);
                // Accumulate bf16 dot-products into the weight tiles.
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                            Tmm(get_src_tensor(icb)),
                            Tmm(get_ddst_tensor(ocb)));
            }
        }
    };

    // Runtime loop over reg_h rows: (reg_h - 1) iterations in the loop body
    // plus one trailing emit_block() after the loop.
    auto emit_h_loop = [&]() {
        Label h_loop, skip_h_loop;
        mov(reg_j, 1);
        cmp(reg_j, reg_h);
        je(skip_h_loop, T_NEAR);
        L(h_loop);
        {
            emit_block();

            add(reg_src, get_src_offset(0, 0, 1));
            add(reg_ddst, get_ddst_offset(0, 1));
            add(reg_j, 1);
            cmp(reg_j, reg_h);
            jb(h_loop);
        }
        L(skip_h_loop);

        emit_block();
    };

    // Runtime loops over kh and kw for one OH block. 'is_first_block' /
    // 'is_last_block' select how top/bottom padding is handled.
    auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
        xor_(reg_kh, reg_kh);
        Label kh_loop, kh_loop_end;

        int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
        // NB: this is correct because we only support t_pad = kh / 2 and thus
        // ih == oh
        int ih_block_size = oh_block_size
                + (!is_first_block + !is_last_block) * jcp.t_pad;

        L(kh_loop);
        {
            if (is_first_block) {
                // ohs = max(t_pad - kh, 0); ihs = ohs - t_pad + kh
                xor_(reg_tmp, reg_tmp);
                mov(reg_ohs, jcp.t_pad);
                sub(reg_ohs, reg_kh);
                cmovb(reg_ohs, reg_tmp);

                mov(reg_ihs, reg_ohs);
                sub(reg_ihs, jcp.t_pad);
                add(reg_ihs, reg_kh);
            } else {
                xor_(reg_ohs, reg_ohs);
                mov(reg_ihs, reg_kh);
            }

            // reg_h = min(oh_block_size - ohs, ih_block_size - ihs)
            mov(reg_tmp, oh_block_size);
            sub(reg_tmp, reg_ohs);
            mov(reg_h, ih_block_size);
            sub(reg_h, reg_ihs);
            cmp(reg_tmp, reg_h);
            cmovb(reg_h, reg_tmp);

            Label kh_loop_work;
            cmp(reg_h, 0);
            jg(kh_loop_work, T_NEAR);

            // empty h loop for this jcp.kh:
            // - set the ddst to 0 if necessary
            // - move ker ptr
            // - jump to the end
            sub(reg_h, 1);
            Label skip_ker_zeroing;

            // The reg_ker ptr has its lowest bit set if the weights need to
            // be zeroed. Those who have byte-aligned their data will suffer
            // the consequences :(
            // TODO: move the flag to a mask register? (Roma)
            test(reg_ker, 1);
            jz(skip_ker_zeroing, T_NEAR);

            Label zeroing_loop;
            vpxord(zmm0, zmm0, zmm0);
            and_(reg_ker, ~1); // temporarily clear the zeroing flag

            // Store a zeroed tile to every (ocb, icb, kw) weight slot.
            mov(reg_dense_stride, 64);
            tilezero(Tmm(get_wei_tensor(0, 0)));
            for (int kw = 0; kw < jcp.kw; kw++) {
                // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, kw)],
                            Tmm(get_wei_tensor(0, 0)));
            }
            // restore the zeroing flag (it will be cleared after the end of
            // emit_kh_kw_loop, but we may need it until then)
            or_(reg_ker, 1);
            jmp(kh_loop_end, T_NEAR);

            L(skip_ker_zeroing);
            add(reg_ker, get_kernel_offset(0, jcp.kw));
            jmp(kh_loop_end, T_NEAR);

            L(kh_loop_work);

            // Convert row indices into byte offsets and advance the data
            // pointers to the first participating row.
            mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
            mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));

            add(reg_src, reg_ihs);
            add(reg_ddst, reg_ohs);

            Label kw_loop;
            xor_(reg_kw, reg_kw);

            mov(reg_dense_stride, 64);
            L(kw_loop);
            {
                // Load the current weight accumulators, or start from zero
                // tiles if the zeroing flag (bit 0 of reg_ker) is set.
                Label do_zero, ker_init_done;
                test(reg_ker, 1);
                jnz(do_zero, T_NEAR);

                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)]);
                jmp(ker_init_done);
                L(do_zero);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilezero(Tmm(get_wei_tensor(ocb, icb)));

                L(ker_init_done);

                // Save the data pointers; emit_h_loop advances them.
                mov(ptr[rsp + ddst_save_offset], reg_ddst);
                mov(ptr[rsp + src_save_offset], reg_src);

                lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
                emit_h_loop();

                mov(reg_ddst, ptr[rsp + ddst_save_offset]);
                mov(reg_src, ptr[rsp + src_save_offset]);

                // The reg_ker ptr has its lowest bit set if the weights need
                // to be zeroed. Those who have byte-aligned their data will
                // suffer the consequences :(
                mov(reg_tmp, reg_ker);
                and_(reg_ker, ~1);

                // Spill the accumulated weight tiles.
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)],
                            Tmm(get_wei_tensor(ocb, icb)));

                mov(reg_ker, reg_tmp);
                add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
                add(reg_kw, 1);
                cmp(reg_kw, jcp.kw);
                jl(kw_loop);
            }

            // Undo the row-offset advances for the next kh iteration.
            sub(reg_src, reg_ihs);
            sub(reg_ddst, reg_ohs);

            L(kh_loop_end);
            add(reg_kh, 1);
            cmp(reg_kh, jcp.kh);
            jl(kh_loop);
        }
    };

    // 'channel' carries the zeroing request; it is OR'ed into reg_ker's
    // lowest bit (see the flag comments above).
    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    or_(reg_ker, reg_tmp);

    bool single_kh_kw_loop = (h_last_block_size == jcp.oh);

    auto src_row_step = get_src_offset(0, 0, 1);
    auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
    auto ddst_block_step = get_ddst_offset(0, h_block_size);

    emit_kh_kw_loop(true, single_kh_kw_loop);

    if (!single_kh_kw_loop) {
        auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
        sub(reg_ker, ker_reset_offset);
        and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates

        add(reg_src, first_src_block_step);
        add(reg_ddst, ddst_block_step);

        // Middle OH blocks (all but the first and last).
        int num_innermost_iters
                = (jcp.oh - h_last_block_size) / h_block_size - 1;
        if (num_innermost_iters > 0) {
            Label h_block_loop;

            mov(reg_tmp_w, num_innermost_iters);
            kmovw(reg_h_block, reg_tmp_w);
            L(h_block_loop);
            {
                emit_kh_kw_loop(false, false);
                sub(reg_ker, ker_reset_offset);
                add(reg_src, src_row_step * h_block_size);
                add(reg_ddst, ddst_block_step);

                // Decrement the loop counter kept in opmask k1.
                kmovw(reg_tmp_w, reg_h_block);
                sub(reg_tmp_w, 1);
                kmovw(reg_h_block, reg_tmp_w);
                jnz(h_block_loop);
            }
        }

        emit_kh_kw_loop(false, true);
    }
}
4345 | |
// Accumulate diff-weights over one ic block: for each stride phase and each
// kw position in that phase, load the weight tiles, accumulate bf16
// dot-products over all ur_w blocks, and store the tiles back.
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
        int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
    assert(jcp.ur_w % 2 == 0);
    const int str_w = jcp.stride_w;
    assert(jcp.tr_iw % str_w == 0);
    // Element distance between consecutive stride phases in the transposed
    // source buffer.
    const int src_stride_w_shift = jcp.tr_iw / str_w;

    mov(reg_b_stride, 64);
    mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);

    for (int s = 0; s < str_w; s++) {
        for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {

            // Load the weight accumulator tiles for this kw position.
            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_kernel + reg_b_stride
                                    + get_full_kernel_offset(
                                            ocb, icb, 0, i_kw)]);

            // Source offset for this kw position within its stride phase.
            int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
                    + s * src_stride_w_shift;

            for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
                dim_t ur_w_src_offset = ur_w_b
                        * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
                dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
                for (int icb = 0; icb < nb_ic_blocking; icb++) {
                    dim_t icb_offset = icb * jcp.tr_src_buf_size;
                    tileloadd(Tmm(get_src_tensor(icb)),
                            ptr[reg_src
                                    + jcp.typesize_in
                                            * (src_offset_l + icb_offset)
                                    + ur_w_src_offset + reg_a_stride]);
                }
                for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                    tileloadd(Tmm(get_ddst_tensor(ocb)),
                            ptr[reg_ddst
                                    + jcp.typesize_in * ocb
                                            * jcp.tr_diff_dst_buf_size
                                    + ur_w_ddst_offset + reg_b_stride]);
                    // Accumulate bf16 dot-products into the weight tiles.
                    for (int icb = 0; icb < nb_ic_blocking; icb++)
                        tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                                Tmm(get_src_tensor(icb)),
                                Tmm(get_ddst_tensor(ocb)));
                }
            }

            // Spill the updated accumulators back to the weights buffer.
            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(ptr[reg_kernel + reg_b_stride
                                       + get_full_kernel_offset(
                                               ocb, icb, 0, i_kw)],
                            Tmm(get_wei_tensor(ocb, icb)));
        }
    }
    // Advance src and kernel pointers past the processed ic block.
    safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
    add(reg_kernel, get_kernel_offset(ic_block, 0));
}
4405 | |
// Prepare bias accumulation for output-channel block 'ocb': broadcast a
// bf16 1.0 multiplicand (for vdpbf16ps) and load the current partial
// diff_bias values from memory.
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(int ocb) {
    auto reg_unit_val = reg_tmp.cvt16();
    mov(reg_unit_val, 0x3f80); // bf16 value of 1.
    vpbroadcastw(vreg_bias_unit, reg_unit_val);

    mov(reg_tmp, ptr[param + GET_OFF(bias)]);
    vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
}
4414 | |
// Accumulate one diff_dst row into the bias accumulator; only emitted for
// the first ic chunk (FLAG_IC_FIRST). When 'is_partial', the accumulator is
// loaded from and stored back to memory around the row.
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
        bool is_partial, int ocb) {
    if (!jcp.with_bias) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);
    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    if (is_partial) { compute_diff_bias_init(ocb); }
    // One step reduces a bf16 pair per lane: acc += ddst * 1.0
    auto compute_step = [&]() {
        vmovups(vreg_bias_ddst, ptr[reg_ddst]);
        vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
    };

    Label ow_loop, ow_tail;
    int niters = jcp.tr_ow / 2;
    if (niters > 0) {
        mov(reg_tmp, jcp.tr_ow / 2);
        L(ow_loop);
        compute_step();
        add(reg_ddst, get_ddst_offset(2));
        sub(reg_tmp, 1);
        jnz(ow_loop, T_NEAR);
    }
    // Odd trailing element.
    if (jcp.tr_ow % 2) compute_step();

    // Rewind reg_ddst back to the start of the row.
    if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));

    if (is_partial) {
        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);
    }

    L(skip_label);
}
4451 | |
// Compute diff_bias over all output rows for each oc block, unless the
// harness computes it row-by-row elsewhere. Guarded at runtime by the
// FLAG_IC_FIRST flag so only the first ic chunk contributes.
void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
        int nb_oc_blocking) {
    // In harness_3d_reduction case calculation of diff_bias is called
    // for every ow row separately to be aligned with od loop in
    // compute_od_loop_common()
    if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);

    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
        Label bias_loop, skip_label_local;

        mov(reg_ddst, ptr[param + GET_OFF(dst)]);
        add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);

        // Row count depends on how the harness partitions the work.
        switch (jcp.harness) {
            case harness_2d_reduction:
                mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
                break;
            case harness_mb_reduction:
            case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
            case harness_3d_reduction:
            default: assert(!"Invalid harness type" );
        }

        cmp(reg_oj, 0);
        jle(skip_label_local, T_NEAR); // nothing to do

        compute_diff_bias_init(ocb);
        L(bias_loop);
        {
            // compute_diff_bias_row keeps the accumulator in a register;
            // advance to the next diff_dst row each iteration.
            compute_diff_bias_row(false, ocb);
            add(reg_ddst, get_ddst_offset(0, 1));

            sub(reg_oj, 1);
            jnz(bias_loop, T_NEAR);
        }

        // Spill the final accumulator for this oc block.
        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);

        L(skip_label_local);
    }
    // restore reg_ddst value
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);

    L(skip_label);
}
4505 | |
// Emit one output-row step: loops over kd (for 5D), kh, and the ic block
// (with optional ic tail handling), then restores the src/kernel pointers.
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
        int nb_ic_blocking, int nb_oc_blocking) {
    Label kh_label, ic_block_label, ow_block_label, kd_label;

    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;

    // Dispatch at runtime between the full ic block and the ic tail based
    // on the 'reduce_work' kernel argument.
    auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
        Label ic_tail_label, ic_loop_done_label;

        if (ic_tail) {
            mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            cmp(reg_icb, jcp.ic_tail);
            jne(ic_tail_label, T_NEAR);

            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
            jmp(ic_loop_done_label, T_NEAR);

            L(ic_tail_label);
            compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
            // Compensate for the tail: advance the kernel past the padded
            // part and fix up the src pointer to the next kh row.
            add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
            safe_add(reg_src,
                    get_src_offset(0, 0, filter_h_to_src(1))
                            - get_src_offset(ic_tail, 0),
                    reg_long_offt);
            L(ic_loop_done_label);
        } else {
            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
        }
    };

    if (jcp.ndims == 5) {
        /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
         * 'movs' must be guaranteed. */
        mov(ki, reg_kd_count);
        mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
        mov(aux_reg_src, reg_src);
        mov(aux_reg_kernel, reg_kernel);

        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    L(kh_label);
    {
        ic_loop(nb_ic_blocking, nb_oc_blocking);

        if (jcp.dilate_h > 0) {
            add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
        }
        // subtract pointer shift made within ic block loop
        // and move to next kh index
        add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_label, T_NEAR);
    }
    if (jcp.ndims == 5) {
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
    // In harness_3d_reduction case calculation of diff_bias is called
    // for every ow row separately to be aligned with od loop in
    // compute_od_loop_common()
    if (jcp.harness == harness_3d_reduction) {
        auto reg_save_ddst = reg_a_stride;
        mov(reg_save_ddst, reg_ddst);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            safe_add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
                    reg_long_offt);
            compute_diff_bias_row(true, ocb);
        }
        mov(reg_ddst, reg_save_ddst);
    }

    // Rewind the data pointers advanced by the kd/kh loops above.
    if (jcp.ndims == 5) {
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
        mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
        od_step_comeback_pointers();
    } else {
        oh_step_comeback_pointers();
    }
}
4595 | |
// Zero the diff-weights buffer (and diff-bias, for the first ic chunk) when
// the 'channel' kernel argument requests it; skipped entirely for the
// full-spatial harness without bias, which zeroes weights elsewhere.
void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
        int nb_ic_blocking, int nb_oc_blocking) {
    if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
    Label skip_zeroing, zeroing_loop;

    // 'channel' != 0 signals that accumulators must start from zero.
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    cmp(reg_tmp, 0);
    jz(skip_zeroing, T_NEAR);

    Zmm zero = Zmm(0);
    vpxord(zero, zero, zero);
    if (jcp.with_bias) {
        Label skip_bias_zeroing;
        mov(reg_tmp, ptr[param + GET_OFF(flags)]);
        test(reg_tmp, FLAG_IC_FIRST);
        jz(skip_bias_zeroing, T_NEAR);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            mov(reg_tmp, ptr[param + GET_OFF(bias)]);
            vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
        }
        L(skip_bias_zeroing);
        // Full-spatial harness only needs the bias zeroed here.
        if (jcp.harness == harness_compute_full_spatial)
            jmp(skip_zeroing, T_NEAR);
    }

    // Store a zeroed tile over the whole kd*kh*kw weights region.
    mov(reg_b_stride, 64);
    tilezero(Tmm(get_wei_tensor(0, 0)));
    for (dim_t shift = 0;
            shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
            shift += get_kernel_offset(jcp.ic_block, 0)) {
        for_(int icb = 0; icb < nb_ic_blocking; icb++)
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            tilestored(
                    ptr[reg_kernel + reg_b_stride
                            + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
                    Tmm(get_wei_tensor(0, 0)));
        }
    }
    L(skip_zeroing);
}
4636 | |
4637 | void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common( |
4638 | int nb_ic_blocking, int nb_oc_blocking, bool is_partial) { |
4639 | int b_pad = jcp.b_pad; |
4640 | int t_pad = jcp.t_pad; |
4641 | |
4642 | bool is_dilated = jcp.dilate_h != 0; |
4643 | int dilate_h = jcp.dilate_h + 1; |
4644 | int stride_h = jcp.stride_h; |
4645 | auto filter_step_size = get_kernel_offset(0, jcp.kw); |
4646 | auto src_step_size = get_src_offset(0, 0, 1); |
4647 | auto ddst_step_size = get_ddst_offset(0, 1); |
4648 | Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end, |
4649 | oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label, |
4650 | oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift, |
4651 | oh_dilate_label_end, oh_dilate_setup_label_shift, |
4652 | oh_dilate_setup_label_noshift, oh_dilate_setup_label_end; |
4653 | |
4654 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
4655 | int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h); |
4656 | int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end); |
4657 | int oh_head_overflow_end = div_up(t_pad, stride_h); |
4658 | int oh_tail_end = jcp.oh; |
4659 | |
4660 | int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h; |
4661 | int ih_body_end |
4662 | = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset); |
4663 | |
4664 | if (is_partial) |
4665 | mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]); |
4666 | else |
4667 | xor_(reg_oj, reg_oj); |
4668 | |
4669 | /* Compute 'top' edge */ |
4670 | if (t_pad > 0) { |
4671 | if (is_partial) { |
4672 | cmp(reg_oj, oh_head_overflow_end); |
4673 | jge(oh_tpad_tail_label_end, T_NEAR); |
4674 | } |
4675 | const int overflow |
4676 | = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); |
4677 | const int underflow = div_up(t_pad, dilate_h); |
4678 | const int initial_kh = jcp.kh - overflow - underflow; |
4679 | |
4680 | // Setup reg_kh, reg_kernel, and reg_src |
4681 | mov(reg_kh, initial_kh); |
4682 | add(reg_kernel, filter_step_size * underflow); |
4683 | if (is_dilated) { |
4684 | const int tail = t_pad % dilate_h; |
4685 | const int shift = tail == 0 ? 0 : dilate_h - tail; |
4686 | mov(reg_ih_shift, shift); |
4687 | if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4688 | add(reg_src, src_step_size * shift); |
4689 | } |
4690 | |
4691 | if (is_partial) { |
4692 | Label head_setup, head_setup_finish; |
4693 | cmp(reg_oj, 0); |
4694 | je(head_setup_finish, T_NEAR); |
4695 | mov(reg_oj_setup, reg_oj); |
4696 | |
4697 | L(head_setup); |
4698 | if (is_dilated) { |
4699 | inc(reg_ih_shift); |
4700 | cmp(reg_ih_shift, dilate_h); |
4701 | jl(oh_dilate_setup_label_shift, T_NEAR); |
4702 | // unshift src as new kernel element enters |
4703 | sub(reg_src, src_step_size * (dilate_h - 1)); |
4704 | xor_(reg_ih_shift, reg_ih_shift); |
4705 | } |
4706 | // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 |
4707 | add(reg_kh, stride_h); |
4708 | sub(reg_kernel, filter_step_size * stride_h); |
4709 | if (is_dilated) { |
4710 | jmp(oh_dilate_setup_label_noshift, T_NEAR); |
4711 | L(oh_dilate_setup_label_shift); |
4712 | // shift src as old kernel element progresses |
4713 | add(reg_src, src_step_size * stride_h); |
4714 | L(oh_dilate_setup_label_noshift); |
4715 | } |
4716 | sub(reg_oj_setup, 1); |
4717 | jg(head_setup, T_NEAR); |
4718 | L(head_setup_finish); |
4719 | |
4720 | if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4721 | if (oh_head_end < oh_head_overflow_end) { |
4722 | cmp(reg_oj, oh_head_end); |
4723 | jge(oh_tpad_label_end, T_NEAR); |
4724 | } |
4725 | } |
4726 | |
4727 | //Setup reg_kernel |
4728 | // If dilated, shift src ptr |
4729 | // Loop |
4730 | L(oh_tpad_label); |
4731 | compute_oh_step_common(nb_ic_blocking, nb_oc_blocking); |
4732 | add(reg_ddst, ddst_step_size); |
4733 | if (is_dilated) { |
4734 | mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]); |
4735 | inc(reg_ih_shift); |
4736 | mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4737 | cmp(reg_ih_shift, dilate_h); |
4738 | jl(oh_dilate_label_shift, T_NEAR); |
4739 | // unshift src as new kernel element enters |
4740 | sub(reg_src, src_step_size * (dilate_h - 1)); |
4741 | xor_(reg_ih_shift, reg_ih_shift); |
4742 | mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4743 | } |
4744 | // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 |
4745 | add(reg_kh, stride_h); |
4746 | sub(reg_kernel, filter_step_size * stride_h); |
4747 | if (is_dilated) { |
4748 | jmp(oh_dilate_label_noshift, T_NEAR); |
4749 | L(oh_dilate_label_shift); |
4750 | // shift src as old kernel element progresses |
4751 | add(reg_src, src_step_size * stride_h); |
4752 | L(oh_dilate_label_noshift); |
4753 | } |
4754 | inc(reg_oj); |
4755 | |
4756 | if (is_partial) { |
4757 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4758 | jge(oh_bpad_label_end, T_NEAR); |
4759 | } |
4760 | cmp(reg_oj, oh_head_end); |
4761 | jl(oh_tpad_label, T_NEAR); |
4762 | |
4763 | L(oh_tpad_label_end); |
4764 | // need second loop to process kernel if it is larger than the src |
4765 | // (does not apply to dilations as they must have unit stride) |
4766 | if (oh_head_end < oh_head_overflow_end) { |
4767 | assert(!is_dilated); |
4768 | |
4769 | cmp(reg_oj, oh_head_overflow_end); |
4770 | jge(oh_tpad_tail_label_end, T_NEAR); |
4771 | |
4772 | mov(reg_kh, jcp.ih); |
4773 | L(oh_tpad_tail_label); |
4774 | { |
4775 | compute_oh_step_common(nb_ic_blocking, nb_oc_blocking); |
4776 | add(reg_ddst, ddst_step_size); |
4777 | sub(reg_kernel, filter_step_size * stride_h); |
4778 | |
4779 | inc(reg_oj); |
4780 | |
4781 | if (is_partial) { |
4782 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4783 | jge(oh_bpad_label_end, T_NEAR); |
4784 | } |
4785 | cmp(reg_oj, oh_head_overflow_end); |
4786 | jl(oh_tpad_tail_label, T_NEAR); |
4787 | } |
4788 | } |
4789 | if (body_src_start_offset != 0) { |
4790 | add(reg_kernel, filter_step_size * body_src_start_offset); |
4791 | add(reg_src, src_step_size * body_src_start_offset); |
4792 | } |
4793 | L(oh_tpad_tail_label_end); |
4794 | } |
4795 | |
4796 | if (is_partial) { |
4797 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4798 | jge(oh_bpad_label_end, T_NEAR); |
4799 | } |
4800 | cmp(reg_oj, oh_body_end); |
4801 | jge(oh_label_end, T_NEAR); |
4802 | |
4803 | /* Compute middle block(s) */ |
4804 | mov(reg_kh, jcp.kh); |
4805 | L(oh_label); |
4806 | { |
4807 | compute_oh_step_common(nb_ic_blocking, nb_oc_blocking); |
4808 | add(reg_src, src_step_size * stride_h); |
4809 | add(reg_ddst, ddst_step_size); |
4810 | |
4811 | inc(reg_oj); |
4812 | |
4813 | if (is_partial) { |
4814 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4815 | jge(oh_bpad_label_end, T_NEAR); |
4816 | } |
4817 | |
4818 | cmp(reg_oj, oh_body_end); |
4819 | jl(oh_label, T_NEAR); |
4820 | } |
4821 | L(oh_label_end); |
4822 | |
4823 | /* Compute bottom edge */ |
4824 | if (b_pad > 0) { |
4825 | if (is_partial) { |
4826 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4827 | jge(oh_bpad_label_end, T_NEAR); |
4828 | } |
4829 | cmp(reg_oj, jcp.oh); |
4830 | jge(oh_bpad_label_end, T_NEAR); |
4831 | |
4832 | if (is_dilated) { |
4833 | // Assumes unit stride for dilations |
4834 | mov(reg_kh, jcp.kh - 1); |
4835 | xor_(reg_ih_shift, reg_ih_shift); |
4836 | } else { |
4837 | assert(jcp.dilate_h == 0); |
4838 | mov(reg_kh, jcp.ih - ih_body_end); |
4839 | } |
4840 | if (is_partial) { |
4841 | lea(reg_oj_setup, |
4842 | ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]); |
4843 | if (stride_h == 1 && !is_dilated) { |
4844 | sub(reg_kh, reg_oj_setup); |
4845 | } else { |
4846 | Label body_setup, body_setup_finish, dilate_skip; |
4847 | cmp(reg_oj_setup, 0); |
4848 | je(body_setup_finish, T_NEAR); |
4849 | |
4850 | L(body_setup); |
4851 | if (is_dilated) { |
4852 | inc(reg_ih_shift); |
4853 | cmp(reg_ih_shift, dilate_h); |
4854 | jl(dilate_skip, T_NEAR); |
4855 | xor_(reg_ih_shift, reg_ih_shift); |
4856 | } |
4857 | sub(reg_kh, stride_h); |
4858 | L(dilate_skip); |
4859 | sub(reg_oj_setup, 1); |
4860 | jg(body_setup, T_NEAR); |
4861 | L(body_setup_finish); |
4862 | } |
4863 | } |
4864 | |
4865 | if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4866 | L(oh_bpad_label); |
4867 | { |
4868 | compute_oh_step_common(nb_ic_blocking, nb_oc_blocking); |
4869 | add(reg_src, src_step_size * stride_h); |
4870 | add(reg_ddst, ddst_step_size); |
4871 | |
4872 | if (is_dilated) { |
4873 | mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]); |
4874 | inc(reg_ih_shift); |
4875 | mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4876 | cmp(reg_ih_shift, dilate_h); |
4877 | jl(oh_dilate_label_end, T_NEAR); |
4878 | xor_(reg_ih_shift, reg_ih_shift); |
4879 | mov(ptr[rsp + ih_dilate_offset], reg_ih_shift); |
4880 | } |
4881 | sub(reg_kh, stride_h); |
4882 | L(oh_dilate_label_end); |
4883 | inc(reg_oj); |
4884 | if (is_partial) { |
4885 | cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]); |
4886 | jge(oh_bpad_label_end, T_NEAR); |
4887 | } |
4888 | cmp(reg_oj, oh_tail_end); |
4889 | jl(oh_bpad_label, T_NEAR); |
4890 | } |
4891 | } |
4892 | L(oh_bpad_label_end); |
4893 | } |
4894 | |
// Emits the loop over the output depth (od) dimension for the 3D-reduction
// harness. Handles front (f_pad) and back (back_pad) depth padding by
// adjusting the kernel pointer and the kd iteration count as the filter
// slides in and out of the padded regions. When is_partial is set, the od
// range is taken from the run-time os_index_begin/os_index_end call
// arguments instead of the full [0, od) range.
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
        int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
    assert(jcp.harness == harness_3d_reduction);

    // First od index whose filter footprint reaches into the back padding.
    const int src_backpad_overlap
            = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);

    // Per-step pointer increments: one kd filter slice, one ih-plane of the
    // source, and one oh-plane of the destination gradient.
    const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
    const auto src_shift = get_src_offset(0, 0, jcp.ih);
    const auto ddst_shift = get_ddst_offset(0, jcp.oh);

    const int kd_front_pad = nstl::max(0, jcp.f_pad);
    const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);

    Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
            backpad_end_label, backpad_label;

    /* initially offset 'kd' by f_pad */
    mov(reg_src_d, ptr[param + GET_OFF(src)]);
    mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);

    if (is_partial) {
        // Partial run: starting od index, kd count and kernel offset come
        // from the run-time call arguments.
        add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
        mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
        mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
    } else {
        // Full run: statically skip the front-padded filter slices.
        const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
        const int kd_offset = get_kernel_offset(
                0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
        add(reg_kernel, kd_offset);
        xor_(reg_d_index, reg_d_index);
        mov(reg_kd_count, kd_padding);
    }

    cmp(reg_kd_count, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kd
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jge(loop_end_label, T_NEAR); // no iterations along depth dimension

    L(d_loop_label);

    mov(reg_src, reg_src_d);
    mov(reg_ddst, reg_ddst_d);

    // Spill depth-loop state: compute_oh_loop_common clobbers these regs.
    mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
    mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
    mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);

    compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);

    // Restore depth-loop state after the inner height loop.
    mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
    mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
    mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));

    /* Compute 'front' edge */
    if (jcp.f_pad > 0) {
        /* Check if within fpad region */
        cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
        jge(fpad_end_label, T_NEAR);

        /* Fpad steps: while in the front pad, each od step brings more
         * kd elements into overlap with the source, so the kernel pointer
         * moves back and the kd count grows. */
        sub(reg_kernel, filter_shift * jcp.stride_d);
        add(reg_kd_count, jcp.stride_d);

        /* Final number of kernel elements that overlap with src */
        const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
        cmp(reg_kd_count, src_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and src */
        if (jcp.f_pad <= jcp.od * jcp.stride_d) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.f_pad % jcp.stride_d != 0) {
                int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
                add(reg_kernel, filter_shift * src_corr);
                add(reg_src_d, src_shift * src_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kd_count, src_ker_overlap);
        jmp(common_block_label);

        L(fpad_end_label);
    }

    /* Compute back-pad ('bottom' along depth) edge */
    if (jcp.back_pad > 0) {

        /* Check if within back_pad region */
        cmp(reg_d_index, src_backpad_overlap - 1);
        jl(backpad_end_label, T_NEAR);
        jg(backpad_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * back_pad region. */
        mov(reg_kd_count,
                jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
        jmp(backpad_end_label, T_NEAR);

        L(backpad_label);
        // Deeper into the back pad: fewer kd elements overlap each step.
        sub(reg_kd_count, jcp.stride_d);
        cmp(reg_kd_count, 0);
        jle(loop_end_label, T_NEAR);

        L(backpad_end_label);
    }

    /* Compute middle block */
    add(reg_src_d, src_shift * jcp.stride_d);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_ddst_d, ddst_shift);
    inc(reg_d_index);
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jl(d_loop_label, T_NEAR);

    L(loop_end_label);
}
5024 | |
5025 | void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop( |
5026 | int nb_ic_blocking, int nb_oc_blocking) { |
5027 | mov(reg_src, ptr[param + GET_OFF(src)]); |
5028 | mov(reg_ddst, ptr[param + GET_OFF(dst)]); |
5029 | mov(reg_kernel, ptr[param + GET_OFF(filt)]); |
5030 | |
5031 | maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking); |
5032 | maybe_compute_diff_bias(nb_oc_blocking); |
5033 | |
5034 | switch (jcp.harness) { |
5035 | case harness_3d_reduction: |
5036 | compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true); |
5037 | break; |
5038 | case harness_2d_reduction: |
5039 | compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true); |
5040 | break; |
5041 | case harness_mb_reduction: |
5042 | compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking); |
5043 | break; |
5044 | case harness_compute_full_spatial: |
5045 | compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking); |
5046 | break; |
5047 | default: assert(!"Invalid harness type" ); |
5048 | } |
5049 | } |
5050 | |
5051 | void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() { |
5052 | kd_count_offset = ic_block_step_stack_size; |
5053 | src_d_offset = ic_block_step_stack_size + 8; |
5054 | ddst_d_offset = ic_block_step_stack_size + 16; |
5055 | d_index_offset = ic_block_step_stack_size + 24; |
5056 | ih_dilate_offset = ic_block_step_stack_size + 32; |
5057 | src_save_offset = ic_block_step_stack_size + 40; |
5058 | ddst_save_offset = ic_block_step_stack_size + 48; |
5059 | stack_space_needed = ic_block_step_stack_size + 56; |
5060 | } |
5061 | |
5062 | void jit_avx512_core_amx_bwd_weights_kernel_t::generate() { |
5063 | preamble(); |
5064 | |
5065 | setup_stack_space(); |
5066 | |
5067 | sub(rsp, stack_space_needed); |
5068 | |
5069 | Label last_ic_block_label, last_blocks_done_label; |
5070 | |
5071 | mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]); |
5072 | cmp(reg_tmp, 0); |
5073 | jne(last_ic_block_label, T_NEAR); |
5074 | { // full nb_ic_blocking |
5075 | Label last_oc_block_label; |
5076 | mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]); |
5077 | cmp(reg_tmp, 0); |
5078 | jne(last_oc_block_label, T_NEAR); |
5079 | { // full nb_oc_blocking |
5080 | compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking); |
5081 | jmp(last_blocks_done_label, T_NEAR); |
5082 | } |
5083 | L(last_oc_block_label); |
5084 | { // tail of nb_oc_blocking |
5085 | compute_loop(jcp.nb_ic_blocking, 1); |
5086 | jmp(last_blocks_done_label, T_NEAR); |
5087 | } |
5088 | } |
5089 | L(last_ic_block_label); |
5090 | { // tail nb_ic_blocking |
5091 | Label last_oc_block_label; |
5092 | mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]); |
5093 | cmp(reg_tmp, 0); |
5094 | jne(last_oc_block_label, T_NEAR); |
5095 | { // full nb_oc_blocking |
5096 | compute_loop(1, jcp.nb_oc_blocking); |
5097 | jmp(last_blocks_done_label, T_NEAR); |
5098 | } |
5099 | L(last_oc_block_label); |
5100 | { // tail of nb_oc_blocking |
5101 | compute_loop(1, 1); |
5102 | jmp(last_blocks_done_label, T_NEAR); |
5103 | } |
5104 | } |
5105 | |
5106 | L(last_blocks_done_label); |
5107 | add(rsp, stack_space_needed); |
5108 | |
5109 | postamble(); |
5110 | } |
5111 | |
// Initializes the jit_conv_conf_t descriptor for the AMX bf16 backward
// weights convolution kernel: reads problem dimensions from the operation
// and memory descriptors, validates implementation constraints, selects
// memory formats, blocking factors and the reduction harness, and balances
// the work across threads. Returns status::unimplemented whenever the
// problem cannot be handled by this kernel.
status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);

    jcp = zero<decltype(jcp)>();

    // With groups, the weights tensor has one extra leading 'g' dimension.
    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    if (!mayiuse(avx512_core_amx)) return status::unimplemented;
    jcp.isa = avx512_core_amx;

    jcp.has_vnni = true; // Needed for transpose routines
    jcp.nthr = nthreads;

    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    // Spatial sizes: 3D (ndims == 5) has depth; 1D (ndims == 3) no height.
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    // Effective filter extents including dilation.
    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    bool ok = true
            // general condition to simplify dilations
            && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
            && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
            && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
            // special condition to simplify dilations in compute_oh_loop_common
            && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
    if (!ok) return status::unimplemented;

    // Only bf16 src/diff_dst; diff_weights may be f32 or bf16.
    ok = true && one_of(ndims, 3, 4, 5)
            && everyone_is(
                    data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
            && one_of(diff_weights_d.data_type(), data_type::f32,
                    data_type::bf16);
    if (!ok) return status::unimplemented;

    jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;

    // Derive right/bottom/back padding from the forward geometry.
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    /* XXX: no support for padding when dilation_d > 0 */
    if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
        return status::unimplemented;

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    // Only channels-last (nspc) activation layouts are supported.
    const int dat_format_tag = ndims - 3;
    format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
            format_tag::nhwc, format_tag::ndhwc);
    format_tag_t dat_tag_opt = dat_tag_nspc;

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
    if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;

    // diff_dst must use the same layout as src.
    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (!jcp.is_nspc) return status::unimplemented;

    // Blocked weights layout; the '2i' variants hold bf16 in VNNI layout.
    const int wei_format_tag = 2 * ndims - 6 + with_groups;
    format_tag_t wei_tag;
    if (jcp.transform_to_vnni)
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
                format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
                format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
                format_tag::gOIdhw16i16o2i);
    else
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
                format_tag::gOIw16i16o, format_tag::OIhw16i16o,
                format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
                format_tag::gOIdhw16i16o);
    if (diff_weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }
    jcp.wei_dt = diff_weights_d.data_type();

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
    }
    jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad_h = ext_kh / 2;
    const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
            && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
            && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
    if (!boundaries_ok) return status::unimplemented;

    // Fixed 16-channel blocking matches the AMX tile geometry below.
    jcp.ic_block = 16;
    jcp.oc_block = 16;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

    jcp.ic_tail = jcp.ic % jcp.ic_block;
    jcp.oc_tail = jcp.oc % jcp.oc_block;

    jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
    jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;

    // This kernel requires the full 8-tile palette with 16-row tiles.
    const int target_palette = amx::get_target_palette();
    jcp.max_tiles = amx::get_max_tiles(target_palette);
    jcp.full_tile_width = amx::get_max_rows(target_palette);

    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    const bool is_2d = (ndims == 4);
    const bool is_3d = (ndims == 5);
    jcp.typesize_in = sizeof(bfloat16_t);
    jcp.typesize_out = sizeof(float);

    // TODO: Find more shapes (especially 3D with large spatials) for which
    // local transposition will be beneficial. Furthermore, for TBB threads
    // more shapes can potentially benefit from spatial blocking
    int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;

    jcp.global_transpose = dnnl_thr_syncable();
    jcp.spatial_blk_size = optimal_blk_size;

    // Transposed-source row is padded and rounded up so a full tile
    // register can be loaded without reading past the buffer.
    const int tr_round = 32; // To load full tile register
    int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
    jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
            * jcp.stride_w;

    jcp.tr_src_num_guard_elems = tr_pad; // upper bound
    jcp.tr_ow = rnd_up(jcp.ow, 2);

    // Pick ur_w: use the whole row if it fits, otherwise the largest
    // divisor of tr_ow not exceeding max_ur_w (searched in steps of 2).
    if (jcp.tr_ow <= max_ur_w) {
        jcp.ur_w = jcp.tr_ow;
        jcp.ur_w_blocks = 1;
    } else {
        jcp.ur_w = 1;
        for (int i = max_ur_w; i >= 1; i -= 2) {
            if (jcp.tr_ow % i == 0) {
                jcp.ur_w = i;
                break;
            }
        }
        jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
    }

    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
            && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            // TODO: Remove this constraint: only 3x3 kernel works now
            && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
            && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
            && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;

    // Select the reduction harness and the amount of parallel mb work.
    jcp.harness = ndims == 5
            ? harness_3d_reduction
            : (use_full_spat_loop ? harness_compute_full_spatial
                            : (ndims == 4) ? harness_2d_reduction
                                           : harness_mb_reduction);
    switch (jcp.harness) {
        case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
        case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
        case harness_compute_full_spatial:
        case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
        default: assert(!"Invalid harness" ); jcp.nthr_mb_work = jcp.mb;
    }
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;

        // TODO: Optimize memory allocation when threaded on height and depth
        jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
        jcp.tr_src_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
                : jcp.nthr;

        jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
        jcp.tr_diff_dst_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
                : jcp.nthr;
    }

    return status::success;
}
5374 | |
// Books all scratchpad buffers required by the backward-weights kernel:
// transposed src/diff_dst staging buffers, thread-synchronization barrier
// contexts, weights/bias reduction workspace, padded-bias storage and the
// AMX tile configuration. Returns status::unimplemented when the total
// scratchpad size exceeds a sanity limit.
status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_dst_md) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    // XXX: See the comment about tr_iw and guarding elements in
    // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
    const size_t tr_src_size
            = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
            + jcp.tr_src_num_guard_elems;
    scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);

    /* prepare synchronization contexts */
    // One barrier per group of threads sharing a transposed-src buffer.
    if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
        const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx, tr_src_bctx_size);
    }

    const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
            * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;

    // Cache-line alignment for the transposed diff_dst buffer.
    const size_t min_align = 64;
    scratchpad.book(
            key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
        const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
    }

    // f32 reduction workspace is needed either for multi-thread mb
    // reduction or for converting accumulated f32 results to bf16 output.
    if (IMPLICATION(jcp.nthr_mb == 1,
                (jcp.with_bias && jcp.bia_dt == data_type::bf16)
                        || jcp.wei_dt == data_type::bf16)) {
        const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
        const size_t bia_size
                = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;

        // bf16 output needs a buffer per thread; f32 output lets the first
        // thread accumulate in place, so one buffer less is required.
        const int num_wei_buffers
                = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
        const int num_bia_buffers = jcp.with_bias
                ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
                                                 : jcp.nthr_mb - 1)
                : 0;

        const size_t wei_bia_reduction_size
                = wei_size * num_wei_buffers + bia_size * num_bia_buffers;

        scratchpad.book<float>(
                key_conv_wei_bia_reduction, wei_bia_reduction_size);

        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx, 1);
    }

    // Padded bias buffer when oc is not a multiple of the block size.
    if (jcp.with_bias
            && ((jcp.oc_without_padding % jcp.oc_block != 0)
                    && jcp.bia_dt == data_type::f32)) {
        scratchpad.book(key_conv_padded_bias,
                jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline

    constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
            << 30; // 32Gb - TODO: may it's too large?
    const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
            * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
    const size_t scratchpad_limit
            = nstl::min(scratchpad_limit_by_absolute_value,
                    scratchpad_limit_by_tensor_sizes);
    if (scratchpad.size() > scratchpad_limit)
        return status::unimplemented;
    else
        return status::success;
}
5456 | |
// Chooses the thread decomposition (nthr_mb x nthr_g x nthr_oc_b x
// nthr_ic_b) for the backward-weights kernel by minimizing a per-thread
// memory-traffic cost model over all feasible factorizations of the
// available threads. All five outputs are written unconditionally.
void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
        int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
        int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;

    const int max_threads = dnnl_get_max_threads();

    if (max_threads < j.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        nthr_ = nthr_g_ = max_threads;
        return;
    }

    // One thread group per convolution group; the remaining threads are
    // split across mb, oc-blocks and ic-blocks.
    nthr_g_ = j.ngroups;
    const int nthr = max_threads / nthr_g_;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). high level optimizer
         * tries to minimize memory consumption. few notes:
         * (n1) if weights tensor size is less than source and destination
         *      tensors we apply the ratio of the source and destination
         *      tensor sizes to weights one as compensation coefficient to
         *      avoid parallelization across batch size only, othervise we
         *      apply additional coefficient to source component based on
         *      performance measurements
         * (n2) use scales based on output vs input channels ratio for source
         *      and destination componets to imporve threading balance across
         *      input and output channels */

        // bf16 activations (2 bytes), f32 weight accumulation (4 bytes).
        const dim_t src_type_size = 2;
        const dim_t wei_type_size = 4;

        dim_t src_size
                = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
        dim_t dst_size
                = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
        dim_t wei_size
                = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;

        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
        float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
                / (j.nb_ic / j.nb_ic_blocking);
        auto get_src_coef = [=]() {
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef
                = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };

        auto get_wei_coef
                = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };

        const float src_coef = get_src_coef();
        const float dst_coef = get_dst_coef();
        const float wei_coef = get_wei_coef();

        // Weighted per-thread data volumes for src, weights and dst.
        float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
                * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
                / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
        float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
                * j.kd * (j.ic_block * j.nb_ic_blocking)
                * (j.oc_block * j.nb_oc_blocking);
        float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
                / j.nthr_mb_work;

        return src_v + dst_v + wei_v;
    };

    float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* find the best thread distribution with lowest memory cost */

    // Exhaustive search: for every mb split, try every oc split and derive
    // the ic split from the remaining threads.
    const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par,
                (j.nb_oc / j.nb_oc_blocking)); // Amount of nb_oc_blocks
        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            int nthr_ic_b = nstl::min(
                    nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));

            float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                nthr_mb_ = nthr_mb;
                nthr_oc_b_ = nthr_oc_b;
                nthr_ic_b_ = nthr_ic_b;
            }
        }
    }

    // If mb parallelism is already dominant, take as much of it as possible.
    if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
        nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;

    assert(nthr_ <= max_threads);
}
5566 | |
5567 | // start of diff_bias kernel |
// Accumulate one spatial row of diff_dst (width jcp.tr_ow, one 16-channel oc
// block) into the per-channel bias accumulator register(s). Columns are
// consumed two at a time; reg_ddst is advanced inside the loop and restored
// to the start of the row before returning.
// NOTE(review): 'ocb' is unused in this body — the caller positions reg_ddst
// at the ocb offset before invoking this routine; confirm before relying on it.
void jit_avx512_core_amx_bwd_bias_kernel_t::compute_diff_bias_row(int ocb) {

    // Emit the accumulation of two spatial positions (2 * oc_block values).
    auto compute_step = [&]() {
        if (jcp.ddst_dt == data_type::bf16) {
            // bf16 data is in vnni layout (pairs interleaved): a dot product
            // with a vector of bf16 ones (vreg_bias_unit) reduces each pair
            // into a single f32 accumulator lane.
            vmovups(vreg_bias_ddst, ptr[reg_ddst]);
            vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
        } else if (jcp.ddst_dt == data_type::f16) {
            // The ddst_dt is in vnni format, (S16c2s) which needs to be
            // reduced along S dimension. Since, we do not have f16_vnni
            // instruction, we try to emulate it.
            // ddst_data: [a,a, b,b ... p,p] in f16
            // req_output = [a+a, b+b, ... p+p] in f32 i.e., [A, B, ... P]

            // [d,d, c,c, b,b, a,a] in f32 from now on
            vcvtph2psx(yreg_bias_ddst0, ptr[reg_ddst]);
            // [h,h, g,g, f,f, e,e] in f32 from now on
            // (+16 bytes = next 8 f16 values)
            vcvtph2psx(yreg_bias_ddst1, ptr[reg_ddst + 16]);
            // [h+h, g+g, d+d, c+c, f+f, e+e, b+b, a+a] i.e., [H, G, D, C, F, E, B, A]
            // vhaddps pairwise-adds within 128-bit lanes, hence the
            // lane-crossed ordering; it is undone at store time in
            // compute_diff_bias() via vpermq.
            vhaddps(yreg_bias_ddst0, yreg_bias_ddst0, yreg_bias_ddst1);
            // accumulate with previous data
            vaddps(yreg_bias_acc0, yreg_bias_acc0, yreg_bias_ddst0);

            // Second half of the oc block (next 16 f16 values).
            vcvtph2psx(yreg_bias_ddst0, ptr[reg_ddst + 32]);
            vcvtph2psx(yreg_bias_ddst1, ptr[reg_ddst + 48]);
            // [p+p, o+o, l+l, k+k, n+n, m+m, j+j, i+i] i.e., [P, O, L, K, N, M, J, I]
            vhaddps(yreg_bias_ddst0, yreg_bias_ddst0, yreg_bias_ddst1);
            vaddps(yreg_bias_acc1, yreg_bias_acc1, yreg_bias_ddst0);
        }
    };

    // Main loop: process pairs of output columns.
    Label ow_loop;
    int niters = jcp.tr_ow / 2;
    if (niters > 0) {
        mov(reg_tmp, jcp.tr_ow / 2);
        L(ow_loop);
        compute_step();
        add(reg_ddst, get_ddst_offset(2));
        sub(reg_tmp, 1);
        jnz(ow_loop, T_NEAR);
    }
    // Odd tail column, if any (reads one extra position's worth of data;
    // the buffer layout provides the padded pair).
    if (jcp.tr_ow % 2) compute_step();

    // Rewind reg_ddst to the beginning of the row for the caller.
    if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
}
5612 | |
// For each of 'nb_oc_blocking' oc blocks: initialize (or reload) the f32
// bias accumulator, sum reg_nrows rows of diff_dst via
// compute_diff_bias_row(), and store the partial result back to the f32
// buffer pointed to by reg_bias.
void jit_avx512_core_amx_bwd_bias_kernel_t::compute_diff_bias(
        int nb_oc_blocking) {

    for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
        Label bias_loop;

        // pointer to diff_dst
        // (each oc block has its own transposed diff_dst buffer)
        mov(reg_ddst, ptr[param + GET_OFF(dst)]);
        add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);

        // number of rows
        mov(reg_oj, reg_nrows);

        // accumulator initialization
        // (f16 uses two ymm halves, bf16 a single zmm)
        if (jcp.ddst_dt == data_type::f16) {
            vpxord(yreg_bias_acc0, yreg_bias_acc0, yreg_bias_acc0);
            vpxord(yreg_bias_acc1, yreg_bias_acc1, yreg_bias_acc1);
        } else {
            vpxord(vreg_bias_acc, vreg_bias_acc, vreg_bias_acc);
        }
        // If reg_initial (the 'channel' call-arg) is non-zero, start from a
        // zeroed accumulator; otherwise fall through and pre-load the partial
        // sums already stored in the bias buffer.
        cmp(reg_initial, 0);
        jnz(bias_loop, T_NEAR);
        const size_t offset = sizeof(float) * ocb * jcp.oc_block;
        if (jcp.ddst_dt == data_type::f16) {
            // the data is in plain format, transform while loading.
            // i.e.,[H, G, F, E, D, C, B, A] -> [H, G, D, C, F, E, B, A]
            // and [P, O, N, M, L, K, J, I] -> [P, O, L, K, N, M, J, I]
            // (matches the lane-crossed order produced by vhaddps in
            // compute_diff_bias_row)
            vpermq(yreg_bias_acc0, ptr[reg_bias + offset], 0xd8);
            vpermq(yreg_bias_acc1,
                    ptr[reg_bias + offset + vreg_traits<Ymm>::vlen], 0xd8);
        } else {
            vmovups(vreg_bias_acc, ptr[reg_bias + offset]);
        }
        // loop by rows
        L(bias_loop);
        {
            compute_diff_bias_row(ocb);
            // advance to the next spatial row of diff_dst
            add(reg_ddst, get_ddst_offset(0, 1));

            sub(reg_oj, 1);
            jnz(bias_loop, T_NEAR);
        }

        // store accumulator
        if (jcp.ddst_dt == data_type::bf16) {
            vmovups(ptr[reg_bias + offset], vreg_bias_acc);
        } else if (jcp.ddst_dt == data_type::f16) {
            // transform to plain before storing.
            // i.e., [H, G, D, C, F, E, B, A] -> [H, G, F, E, D, C, B, A]
            // and [P, O, L, K, N, M, J, I] -> [P, O, N, M, L, K, J, I]
            // (vpermq with 0xd8 is its own inverse)
            vpermq(yreg_bias_acc0, yreg_bias_acc0, 0xd8);
            vpermq(yreg_bias_acc1, yreg_bias_acc1, 0xd8);
            vmovups(ptr[reg_bias + offset], yreg_bias_acc0);
            vmovups(ptr[reg_bias + offset + vreg_traits<Ymm>::vlen],
                    yreg_bias_acc1);
        }
    }
}
5671 | |
5672 | void jit_avx512_core_amx_bwd_bias_kernel_t::generate() { |
5673 | preamble(); |
5674 | |
5675 | Label end_label; |
5676 | // number of rows |
5677 | mov(reg_nrows, ptr[param + GET_OFF(os_index_end)]); |
5678 | sub(reg_nrows, ptr[param + GET_OFF(os_index_begin)]); |
5679 | cmp(reg_nrows, 0); |
5680 | jle(end_label, T_NEAR); // nothing to do |
5681 | |
5682 | if (jcp.ddst_dt == data_type::bf16) { |
5683 | auto reg_unit_val = reg_tmp.cvt16(); |
5684 | mov(reg_unit_val, 0x3f80); // bf16 value of 1. |
5685 | vpbroadcastw(vreg_bias_unit, reg_unit_val); |
5686 | } |
5687 | mov(reg_bias, ptr[param + GET_OFF(bias)]); |
5688 | mov(reg_initial, ptr[param + GET_OFF(channel)]); |
5689 | |
5690 | Label last_oc_block_label; |
5691 | mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]); |
5692 | cmp(reg_tmp, 0); |
5693 | jne(last_oc_block_label, T_NEAR); |
5694 | { // full nb_oc_blocking |
5695 | compute_diff_bias(jcp.nb_oc_blocking); |
5696 | jmp(end_label, T_NEAR); |
5697 | } |
5698 | L(last_oc_block_label); |
5699 | { // tail of nb_oc_blocking |
5700 | compute_diff_bias(1); |
5701 | jmp(end_label, T_NEAR); |
5702 | } |
5703 | |
5704 | L(end_label); |
5705 | |
5706 | postamble(); |
5707 | } |
5708 | // end of diff_bias kernel |
5709 | |
5710 | } // namespace x64 |
5711 | } // namespace cpu |
5712 | } // namespace impl |
5713 | } // namespace dnnl |
5714 | |
5715 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
5716 | |