1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/nstl.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/platform.hpp"
23#include "cpu/x64/cpu_barrier.hpp"
24#include "cpu/x64/injectors/injector_utils.hpp"
25#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
26#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
27#include "cpu/x64/jit_avx512_common_conv_kernel.hpp"
28
29#define GET_OFF(field) offsetof(jit_conv_call_s, field)
30#define KNx_L2_EFFECTIVE_CAPACITY ((512 - 64) * 1024)
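// GET_OFF(field) gives the byte offset of a runtime argument inside the
// jit_conv_call_s structure; the generated code reads its arguments as
// ptr[param1 + GET_OFF(...)]. KNx_L2_EFFECTIVE_CAPACITY is (512 - 64) KiB,
// presumably the per-core L2 capacity assumed usable on KNx parts after
// reserving some headroom.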
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37using namespace dnnl::impl::format_tag;
38using namespace dnnl::impl::memory_tracking::names;
39using namespace dnnl::impl::utils;
40using namespace Xbyak;
41
42namespace {
43
44constexpr auto small_spatial = 14;
45
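// Loop-order heuristic (as encoded below): when both spatial extents fit
// within small_spatial, the cwgn order is chosen so the small tiles stay
// resident; otherwise gncw is used. Plain-layout (nxc) sources with more
// than one group and oc < 16 fall back to nhwcg.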
46inline void pick_loop_order(jit_conv_conf_t &jcp) {
47 using namespace prop_kind;
48 assert(one_of(
49 jcp.prop_kind, forward_training, forward_inference, backward_data));
50 auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
51 auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
52
53 // The w in the loop order is currently ignored by 3D BWD_D
54 jcp.loop_order = (w <= small_spatial && h <= small_spatial) ? loop_cwgn
55 : loop_gncw;
56 if (utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
57 format_tag::nwc)
58 && jcp.ngroups > 1 && jcp.oc < 16)
59 jcp.loop_order = loop_nhwcg;
60}
61
62inline status_t init_tag(format_tag_t &tag, memory_desc_t &md,
63 const memory_desc_wrapper &mdw, const format_tag_t tag_value) {
64 if (mdw.format_kind() == format_kind::any) {
65 CHECK(memory_desc_init_by_tag(md, tag_value));
66 tag = tag_value;
67 } else {
68 tag = mdw.matches_one_of_tag(tag_value);
69 }
70
71 if (tag != tag_value) return status::unimplemented;
72
73 return status::success;
74}
75
76inline bool is_1stconv(const jit_conv_conf_t &jcp) {
77 if (mayiuse(avx512_core))
78 return (jcp.ic < 16 && jcp.ngroups == 1);
79 else
80 return one_of(jcp.ic, 1, 3);
81}
82
83inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
84 return (jcp.nb_ow > 1);
85}
86
87inline bool is_iw_threading_on(const jit_conv_conf_t &jcp) {
88 return (jcp.nb_iw > 1);
89}
90
91} // namespace
92
93template <typename Vmm>
94_jit_avx512_common_conv_fwd_kernel<Vmm>::_jit_avx512_common_conv_fwd_kernel(
95 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
96 const memory_desc_t &dst_md)
97 : jit_generator(jit_name()), jcp(ajcp), attr_(attr) {
98 if (jcp.with_eltwise || jcp.with_binary) {
99 using namespace binary_injector;
100 static constexpr bool preserve_gpr = true;
101 static constexpr bool preserve_vmm = false;
102 static constexpr size_t helper_vmm_idx = 31;
103 const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
104 static constexpr bool use_exact_tail_scalar_bcast = false;
105
106 const binary_injector::rhs_arg_static_params_t rhs_args_static_params {
107 helper_vmm_idx, reg_tmp, r15, r14, preserve_gpr, preserve_vmm,
108 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
109 memory_desc_wrapper(dst_md), tail_size, postops_mask,
110 use_exact_tail_scalar_bcast};
111 const binary_injector::static_params_t static_params {
112 this->param1, rhs_args_static_params};
113
114 postops_injector_ = utils::make_unique<
115 injector::jit_uni_postops_injector_t<avx512_core>>(
116 this, jcp.post_ops, static_params);
117 }
118}
119
120template <typename Vmm>
121void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w) {
122 for (int k = 0; k < jcp.nb_oc_blocking; k++)
123 for (int j = 0; j < ur_w; j++) {
124 Vmm vmm = vmm_out(j, k);
125 vpxord(vmm, vmm, vmm);
126 }
127}
128
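// The iterate() helpers below visit every (oc block, unrolled ow position)
// pair and tell the callback whether that access needs tail masking: either
// masking is forced, or an oc tail exists and this is the last oc block.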
129template <typename F>
130static void iterate(const int nb_oc_blocking, const int ur_w,
131 const bool oc_tail, const bool force_masking, const F &fun) {
132 for (int i_load = 0; i_load < nb_oc_blocking; i_load++) {
133 const auto mask_flag
134 = force_masking || (oc_tail && i_load + 1 == nb_oc_blocking);
135 for (int i_ur = 0; i_ur < ur_w; i_ur++)
136 fun(mask_flag, i_load, i_ur);
137 }
138}
139template <typename F>
140static void iterate(const int nb_oc_blocking, const int ur_w, const F &fun) {
141 iterate(nb_oc_blocking, ur_w, false, false, fun);
142}
143
144template <typename Vmm>
145void _jit_avx512_common_conv_fwd_kernel<Vmm>::apply_postops(int ur_w) {
146 injector_utils::vmm_index_set_t vmm_idxs;
147 if (jcp.with_binary) {
148 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
149 const bool mask_tail = jcp.oc_without_padding % jcp.simd_w;
150 const bool oc_blk_is_smaller_than_vmm = jcp.oc_block < isa_simd_width_;
151 iterate(jcp.nb_oc_blocking, ur_w, mask_tail, oc_blk_is_smaller_than_vmm,
152 [&](const bool mask_flag, const int i_load, const int i_ur) {
153 const size_t aux_output_l_off
154 = get_output_offset(i_ur, i_load);
155 const auto vmm_idx = vmm_out_idx(i_ur, i_load);
156 vmm_idxs.emplace(vmm_idx);
157
158 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out);
159 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
160 vmm_idx, aux_output_l_off);
161 if (mask_flag) {
162 rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
163 }
164 });
165
166 postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
167 } else {
168 iterate(jcp.nb_oc_blocking, ur_w,
169 [&](const bool, const int i_load, const int i_ur) {
170 vmm_idxs.emplace(vmm_out_idx(i_ur, i_load));
171 });
172 postops_injector_->compute_vector_range(vmm_idxs);
173 }
174}
175
176template <typename Vmm>
177void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w) {
178 Label no_update_label, store_label, post_ops_label;
179
180 mov(reg_channel, ptr[param1 + GET_OFF(flags)]);
181
182 if (jcp.with_bias) { mov(reg_bias, ptr[param1 + GET_OFF(bias)]); }
183 const int oc_tail = jcp.oc_tail;
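    // Rough control flow of this routine: reg_channel holds the ic-chunk
    // flags. On the first chunk (FLAG_IC_FIRST) the accumulators are
    // initialized from the bias (plus the existing dst values when a sum
    // post-op is present); on later chunks the previously stored partial
    // sums are accumulated instead. Eltwise/binary post-ops run only on the
    // last chunk (FLAG_IC_LAST), right before the final store.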
184
185 if (!jcp.with_sum) {
186 test(reg_channel, FLAG_IC_FIRST);
187 jnz(no_update_label, T_NEAR);
188 }
189
190 for (int k = 0; k < jcp.nb_oc_blocking; k++)
191 for (int j = 0; j < ur_w; j++) {
192 Vmm vmm = vmm_out(j, k);
193 // mask only needed for last oc_block
194 if (oc_tail && k + 1 == jcp.nb_oc_blocking)
195 vmm = vmm | k_oc_tail_mask | T_z;
196 size_t aux_output_offset = get_output_offset(j, k);
197 vaddps(vmm,
198 make_safe_addr(
199 reg_out, aux_output_offset, reg_out_long_offt));
200 }
201
202 if (!jcp.with_sum) {
203 jmp(post_ops_label, T_NEAR);
204 } else {
205 test(reg_channel, FLAG_IC_FIRST);
206 jz(post_ops_label, T_NEAR);
207 }
208
209 L(no_update_label);
210 if (jcp.with_bias) {
211 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
212 int bias_offset = jcp.typesize_out * k * jcp.oc_block;
213 for (int j = 0; j < ur_w; j++) {
214 Vmm vmm = vmm_out(j, k);
215 // mask only needed for last oc_block
216 if (oc_tail && k + 1 == jcp.nb_oc_blocking)
217 vmm = vmm | k_oc_tail_mask | T_z;
218 vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset));
219 }
220 }
221 }
222
223 L(post_ops_label);
224
225 if (jcp.with_eltwise || jcp.with_binary) {
226 test(reg_channel, FLAG_IC_LAST);
227 jz(store_label, T_NEAR);
228
229 apply_postops(ur_w);
230 }
231
232 L(store_label);
233
234 const auto is_padding = jcp.oc_without_padding != jcp.oc;
235 for (int k = 0; k < jcp.nb_oc_blocking; k++)
236 for (int j = 0; j < ur_w; j++) {
237 Vmm vmm = vmm_out(j, k);
238 // mask only needed for last oc_block
239 if (oc_tail && k + 1 == jcp.nb_oc_blocking) {
240 if (is_padding)
241 vmovups(vmm | k_oc_tail_mask | T_z, vmm);
242 else
243 vmm = vmm | k_oc_tail_mask;
244 }
245 size_t aux_output_offset = get_output_offset(j, k);
246
247 vmovups(EVEX_compress_addr_safe(
248 reg_out, aux_output_offset, reg_out_long_offt),
249 vmm);
250 }
251}
252
253template <typename Vmm>
254void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(
255 int ur_w, int pad_l, int pad_r) {
256 const bool is_source_layout_nxc = is_src_layout_nxc();
257 const bool icb_loop_in_compute_function = is_source_layout_nxc;
258 const int ic_tail = jcp.ic_tail;
259 const int oc_tail = jcp.oc == jcp.oc_without_padding ? jcp.oc_tail : 0;
260 int iw = jcp.iw;
261 int kw = jcp.kw;
262 int ic_block = jcp.ic_block;
263 int oc_block = jcp.oc_block;
264 int nb_oc_block = jcp.nb_oc_blocking;
265 Label kh_label, kd_label;
266 std::vector<Label> ic_tail_jmp(kw);
267
268    // This compute_loop currently handles only a single oc block; the assert
269    // below catches unpadded oc_tail cases if it is ever extended.
270 assert(IMPLICATION(oc_tail, nb_oc_block == 1));
271
272 int num_ker_loads = ic_block * nb_oc_block * kw;
273 int ker_pipeline_depth
274 = oc_tail || ic_tail ? 1 : nstl::min(4, num_ker_loads);
275 assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
276 assert(oc_block >= ker_pipeline_depth);
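    // Weight loads are software-pipelined: up to ker_pipeline_depth vmm
    // registers are filled ahead of the FMAs that consume them (see the
    // step-based logic below), which helps hide load latency. With ic/oc
    // tails the depth is 1, presumably because the tail early-exit would
    // break the pipelined load pattern.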
277
278 int inp_mul = is_source_layout_nxc ? jcp.ngroups * jcp.ic
279 : (!jcp.is_1stconv ? ic_block : 1);
280
281 if (one_of(jcp.ndims, 3, 4)) {
282 mov(aux_reg_inp, reg_inp);
283 mov(aux_reg_ker, reg_ker);
284 }
285
286 if (jcp.ndims == 5) {
287 push(reg_out);
288
289 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
290 if (icb_loop_in_compute_function) {
291 // need to continue with the same kernel pointer, but as
292 // aux_reg_ker_d == reg_ker we need to save its value and restore
293 // it after kd loop
294 assert(aux_reg_ker_d == reg_ker);
295 push(aux_reg_ker_d);
296 } else
297 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
298 mov(aux_reg_inp_d, reg_inp);
299
300 L(kd_label);
301 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
302 } else {
303 mov(reg_kj, reg_kh);
304 }
305
306 if (jcp.ndims == 5) {
307 mov(aux_reg_inp, aux_reg_inp_d);
308 mov(aux_reg_ker, aux_reg_ker_d);
309 }
310
311 align(16);
312 L(kh_label);
313 {
314 int step = 0;
315 for (int ki = 0; ki < kw; ki++) {
316 for (int ic = 0; ic < ic_block; ic++) {
317 if (ic_tail && ic >= ic_tail) {
318 // if src has only tails to compute, skip early
319 if (jcp.ic == ic_tail)
320 break;
321 else if (ic == ic_tail) {
322 cmp(reg_channel, ic_tail);
323 je(ic_tail_jmp[ki], T_NEAR);
324 }
325 }
326 int aux_kernel_offset = 0;
327 if (step == 0) {
328 for (int i = 0; i < ker_pipeline_depth; i++) {
329 aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
330 vmovups(vmm_ker(i),
331 EVEX_compress_addr(
332 aux_reg_ker, aux_kernel_offset));
333 }
334 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
335 int load_offset = ker_pipeline_depth - 1;
336 int ker_load_reg_idx
337 = (step + load_offset) % ker_pipeline_depth;
338 aux_kernel_offset
339 = get_kernel_offset(ki, ic, 0, load_offset);
340 vmovups(vmm_ker(ker_load_reg_idx),
341 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
342 }
343
344 Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
345 int j_start = get_ow_start(ki, pad_l);
346 int j_end = get_ow_end(ur_w, ki, pad_r);
347 for (int j = j_start; j < j_end; j++) {
348 size_t aux_input_offset
349 = get_input_offset(ki, ic, j, pad_l);
350 auto addr = EVEX_compress_addr_safe(
351 aux_reg_inp, aux_input_offset, reg_long_offt, true);
352 vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
353 }
354 step++;
355 }
356 L(ic_tail_jmp[ki]);
357 }
358 int ker_shift = jcp.typesize_in * kw * oc_block * ic_block;
359 add(aux_reg_ker, ker_shift);
360 int inp_shift = jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul;
361 add(aux_reg_inp, inp_shift);
362 dec(reg_kj);
363 cmp(reg_kj, 0);
364 jg(kh_label, T_NEAR);
365 }
366
367 if (jcp.ndims == 5) {
368 int inp_shift
369 = typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul;
370 add(aux_reg_inp_d, inp_shift);
371 int ker_shift
372 = typesize * jcp.kw * jcp.kh * jcp.oc_block * jcp.ic_block;
373 add(aux_reg_ker_d, ker_shift);
374
375 dec(reg_ki);
376 cmp(reg_ki, 0);
377 jg(kd_label, T_NEAR);
378
379 if (icb_loop_in_compute_function) pop(aux_reg_ker_d);
380 pop(reg_out);
381 }
382}
383
384template <typename Vmm>
385void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(
386 int ur_w, int pad_l, int pad_r) {
387 int kw = jcp.kw;
388 int ic_block = jcp.ic_block;
389 int oc_block = jcp.oc_block;
390 int nb_oc_block = jcp.nb_oc_blocking;
391 const bool is_source_layout_nxc = is_src_layout_nxc();
392 const bool icb_loop_in_compute_function = is_source_layout_nxc;
393 const int ic_tail = jcp.ic_tail;
394
395 Label kh_label, kd_label;
396 std::vector<Label> ic_tail_jmp(kw);
397 int shift_kernel_ptr
398 = jcp.typesize_in * jcp.kw * jcp.oc_block * jcp.ic_block;
399 int inp_mul = is_source_layout_nxc ? jcp.ngroups * jcp.ic
400 : (!jcp.is_1stconv ? ic_block : 1);
401
402 int shift_input_ptr
403 = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * inp_mul;
404
405 if (one_of(jcp.ndims, 3, 4)) {
406 mov(aux_reg_inp, reg_inp);
407 mov(aux_reg_ker, reg_ker);
408 }
409
410 if (jcp.ndims == 5) {
411 push(reg_out);
412
413 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
414 if (icb_loop_in_compute_function) {
415 // need to continue with the same kernel pointer, but as
416 // aux_reg_ker_d == reg_ker we need to save its value and restore
417 // it after kd loop
418 assert(aux_reg_ker_d == reg_ker);
419 push(aux_reg_ker_d);
420 } else
421 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
422
423 mov(aux_reg_inp_d, reg_inp);
424
425 L(kd_label);
426 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
427 } else {
428 mov(reg_kj, reg_kh);
429 }
430
431 if (jcp.ndims == 5) {
432 mov(aux_reg_inp, aux_reg_inp_d);
433 mov(aux_reg_ker, aux_reg_ker_d);
434 }
435
436 L(kh_label);
437 {
438 for (int ki = 0; ki < kw; ki++) {
439 int jj_start = get_ow_start(ki, pad_l);
440 int jj_end = get_ow_end(ur_w, ki, pad_r);
441 for (int ic = 0; ic < ic_block; ic++) {
442 if (ic_tail && ic >= ic_tail) {
443 // if src has only tails to compute, skip early
444 if (jcp.ic == ic_tail)
445 break;
446 else if (ic == ic_tail) {
447 cmp(reg_channel, ic_tail);
448 je(ic_tail_jmp[ki], T_NEAR);
449 }
450 }
451 if (jcp.kernel_kind == expl_bcast) {
452 for (int jj = jj_start; jj < jj_end; jj++) {
453 size_t aux_input_offset
454 = get_input_offset(ki, ic, jj, pad_l);
455 vbroadcastss(vmm_inp(jj, nb_oc_block),
456 EVEX_compress_addr_safe(aux_reg_inp,
457 aux_input_offset, reg_long_offt));
458 }
459 }
460 for (int ii = 0; ii < nb_oc_block; ii++) {
461 int aux_kernel_offset = jcp.typesize_in
462 * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd
463 * ic_block * oc_block
464 + ki * ic_block * oc_block + ic * oc_block);
465 if (jj_end - jj_start > 0)
466 vmovups(vmm_wei,
467 EVEX_compress_addr(
468 aux_reg_ker, aux_kernel_offset));
469 for (int jj = jj_start; jj < jj_end; jj++)
470 if (jcp.kernel_kind == expl_bcast)
471 vfmadd231ps(vmm_out(jj, ii),
472 vmm_inp(jj, nb_oc_block), vmm_wei);
473 else {
474 size_t aux_input_offset
475 = get_input_offset(ki, ic, jj, pad_l);
476 vfmadd231ps(vmm_out(jj, ii), vmm_wei,
477 EVEX_compress_addr_safe(aux_reg_inp,
478 aux_input_offset, reg_long_offt,
479 true));
480 }
481 }
482 }
483 L(ic_tail_jmp[ki]);
484 }
485 add(aux_reg_ker, shift_kernel_ptr);
486 add(aux_reg_inp, shift_input_ptr);
487 dec(reg_kj);
488 cmp(reg_kj, 0);
489 jg(kh_label, T_NEAR);
490 }
491
492 if (jcp.ndims == 5) {
493 add(aux_reg_inp_d,
494 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
495 const int ker_shift
496 = typesize * jcp.kw * jcp.kh * jcp.oc_block * jcp.ic_block;
497 add(aux_reg_ker_d, ker_shift);
498
499 dec(reg_ki);
500 cmp(reg_ki, 0);
501 jg(kd_label, T_NEAR);
502
503 if (icb_loop_in_compute_function) pop(aux_reg_ker_d);
504 pop(reg_out);
505 }
506}
507
508template <typename Vmm>
509void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(
510 int ur_w, int pad_l, int pad_r) {
511 if (jcp.ndims == 5) push(reg_oi);
512
513 prepare_output(ur_w);
514
515 Label skip_compute_loop;
516 if (jcp.ndims == 5) {
517 if ((jcp.dilate_d >= jcp.id)
518 || (jcp.kd - 1) * (jcp.dilate_d + 1)
519 < nstl::max(jcp.f_pad, jcp.back_pad)) {
520 mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
521 cmp(reg_kj, 0);
522 jle(skip_compute_loop, T_NEAR);
523 }
524 }
525 if ((jcp.dilate_h >= jcp.ih)
526 || (jcp.kh - 1) * (jcp.dilate_h + 1)
527 < nstl::max(jcp.t_pad, jcp.b_pad)) {
528 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
529 cmp(reg_kj, 0);
530 jle(skip_compute_loop, T_NEAR);
531 }
532
533 Label ic_loop;
534 const bool generate_icb_loop = jcp.nb_ic > 1 && is_src_layout_nxc();
535 if (generate_icb_loop) {
536 push(reg_inp);
537 push(reg_ker);
538
539 mov(reg_channel, ptr[param1 + GET_OFF(reduce_work)]);
540 L(ic_loop);
541 }
542
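    // Kernel dispatch: the 1st convolution (unless expl_bcast) and the
    // embd_bcast kernel with a single oc block use compute_loop_fma
    // (pipelined weight loads); everything else goes through
    // compute_loop_fma_core, which either broadcasts inputs explicitly
    // (expl_bcast) or folds them into the FMA memory operand.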
543 if (jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
544 compute_loop_fma(ur_w, pad_l, pad_r);
545 else if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
546 compute_loop_fma(ur_w, pad_l, pad_r);
547 else
548 compute_loop_fma_core(ur_w, pad_l, pad_r);
549
550 if (generate_icb_loop) {
551 assert(is_src_layout_nxc());
552 const int inp_shift = jcp.ic_block * jcp.typesize_in;
553 add(reg_inp, inp_shift);
554 const size_t ker_shift = (size_t)jcp.kd * jcp.kh * jcp.kw * jcp.ic_block
555 * jcp.oc_block * jcp.typesize_in;
556 safe_add(reg_ker, ker_shift, reg_ker_long_offt);
557 sub(reg_channel, jcp.ic_block);
558 jg(ic_loop, T_NEAR);
559
560 pop(reg_ker);
561 pop(reg_inp);
562 }
563
564 L(skip_compute_loop);
565 store_output(ur_w);
566 if (jcp.ndims == 5) pop(reg_oi);
567}
568
569template <typename Vmm>
570void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate() {
571 int iw = jcp.iw;
572 int ow = jcp.ow;
573 int ow_block = jcp.ow_block;
574 int nb_ow = jcp.nb_ow;
575 int kw = jcp.kw;
576 int l_pad = jcp.l_pad;
577 int ur_w = jcp.ur_w;
578 int ur_w_tail = jcp.ur_w_tail;
579 int stride_w = jcp.stride_w;
580
581 int inp_mult = is_src_layout_nxc() ? jcp.ngroups * jcp.ic
582 : (jcp.is_1stconv ? 1 : jcp.ic_block);
583 int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
584 int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
585 int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
586 int out_shift = jcp.typesize_out * ur_w
587 * (is_dst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block);
588
589 preamble();
590 mov(reg_inp, ptr[param1 + GET_OFF(src)]);
591 mov(reg_out, ptr[param1 + GET_OFF(dst)]);
592 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
593 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
594
595 const int oc_tail = jcp.oc_tail;
596 if (oc_tail) {
597 Label done;
598 // dummy mask all 1's
599 kxnorw(k_oc_tail_mask, k_oc_tail_mask, k_oc_tail_mask);
600 mov(reg_load_work, ptr[param1 + GET_OFF(load_work)]);
601 cmp(reg_load_work, jcp.nb_oc_blocking * jcp.oc_block);
602 je(done, T_NEAR);
603 Reg32 reg_tail_32 = reg_tail.cvt32();
604 mov(reg_tail_32, (1 << oc_tail) - 1);
605 kmovw(k_oc_tail_mask, reg_tail_32);
606 L(done);
607 kmovw(postops_mask, k_oc_tail_mask);
608 } else if (jcp.with_binary)
609 if (jcp.oc_block != isa_simd_width_) {
610 const int mask = (1 << jcp.oc_block) - 1;
611 const Reg32 reg_tail_32 = reg_tail.cvt32();
612 mov(reg_tail_32, mask);
613 kmovw(postops_mask, reg_tail_32);
614 }
615
616 int r_pad = nstl::max(0, jcp.r_pad);
617 int n_oi = ow / ur_w;
618 int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
619 calculate_extended_filter_size(kw, jcp.dilate_w));
620
621 if (!is_ow_threading_on(jcp)) {
622 // ow is being processed as a whole - with left and right paddings
623 if (r_pad1 > 0) n_oi--;
624
625 if (ow == ur_w) {
626 compute_loop(ur_w, l_pad, r_pad);
627 } else {
628 if (n_oi == 0) {
629 compute_loop(ur_w, l_pad, r_pad1);
630 add(reg_inp, inp_shift_pad);
631 add(reg_out, out_shift);
632 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
633 } else {
634 xor_(reg_oi, reg_oi);
635 if (l_pad > 0) {
636 compute_loop(ur_w, l_pad, 0);
637 add(reg_inp, inp_shift_pad);
638 add(reg_out, out_shift);
639 inc(reg_oi);
640 }
641 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
642 Label ow_loop_label;
643 L(ow_loop_label);
644 {
645 compute_loop(ur_w, 0, 0);
646 add(reg_inp, inp_shift);
647 add(reg_out, out_shift);
648 inc(reg_oi);
649 cmp(reg_oi, n_oi);
650 jl(ow_loop_label, T_NEAR);
651 }
652 }
653 if (r_pad1 > 0) {
654 compute_loop(ur_w, 0, r_pad1);
655 add(reg_inp, inp_shift);
656 add(reg_out, out_shift);
657 }
658 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
659 }
660 }
661 } else {
662        // Only a single ow block is processed here.
663        // The block index is passed via the owb parameter,
664        // and the padding handling depends on it.
665
666 Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
667 Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
668
669 assert(ow_block % ur_w == 0);
670 int n_oi_not_last_ow_block = ow_block / ur_w;
671        // to simplify the code (and general-purpose register usage),
672        // the ow block size must be >= 2 * ur_w
673 assert(n_oi_not_last_ow_block > 1);
674 int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
675 int n_oi_first_ow_block = n_oi_not_last_ow_block;
676
677 int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;
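        // Illustrative numbers (hypothetical shape): ow = 35, ur_w = 7,
        // ow_block = 14 -> nb_ow = 3; each non-last block runs 14 / 7 = 2
        // iterations, and the last block runs (35 - 2 * 14) / 7 = 1.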
678
679 // prepare right padding
680 bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
681 bool first_ow_block_padded
682 = next_last_ow_block_padded && jcp.nb_ow == 2;
683 bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
684
685 if (last_ow_block_padded)
686 n_oi_last_ow_block--;
687 else if (first_ow_block_padded)
688 n_oi_first_ow_block--;
689 else if (next_last_ow_block_padded)
690 n_oi_next_last_ow_block--;
691
692 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
693        cmp(reg_owb, 0); // is this the first ow-block?
694 jg(middle_ow_blocks_label, T_NEAR);
695
696 // the first ow block, compute left padding
697
698 mov(reg_oi, n_oi_first_ow_block);
699
700 if (l_pad > 0) {
701 compute_loop(ur_w, l_pad, 0);
702 add(reg_inp, inp_shift_pad);
703 add(reg_out, out_shift);
704 dec(reg_oi);
705 }
706 jmp(oi_loop_label, T_NEAR);
707
708 // middle or last ow block entry
709
710 L(middle_ow_blocks_label);
711
712 if (l_pad > 0) {
713            // only account for the left padding here; nothing to compute
714 add(reg_inp, inp_shift_pad_second_block);
715 }
716
717 // set number of iteration for oi-loop
718 cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
719 mov(reg_oi, n_oi_last_ow_block);
720 je(oi_loop_label, T_NEAR);
721 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
722 mov(reg_oi, n_oi_next_last_ow_block);
723 je(oi_loop_label, T_NEAR);
724 mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
725
726 // oi loop w/o padding
727 L(oi_loop_label);
728 L(oi_loop_start_label);
729 cmp(reg_oi, 0);
730 jle(oi_loop_end_label, T_NEAR);
731
732 compute_loop(ur_w, 0, 0);
733 add(reg_inp, inp_shift);
734 add(reg_out, out_shift);
735 dec(reg_oi);
736 jmp(oi_loop_start_label, T_NEAR);
737 L(oi_loop_end_label);
738
739 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
740
741 cmp(reg_owb, 0); // first ow-block ?
742 if (first_ow_block_padded) {
743 je(last_oi_label, T_NEAR);
744 } else {
745 je(end_label, T_NEAR);
746 }
747 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
748 jl(end_label, T_NEAR);
749 if (next_last_ow_block_padded) {
750 je(last_oi_label, T_NEAR);
751 } else {
752 je(end_label, T_NEAR);
753 }
754        // this is the last block
755 if (!last_ow_block_padded) { jmp(tail_label, T_NEAR); }
756
757 // last oi block with right padding
758 L(last_oi_label);
759 compute_loop(ur_w, 0, r_pad1);
760 add(reg_inp, inp_shift);
761 add(reg_out, out_shift);
762
763 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
764 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
765 jl(end_label, T_NEAR);
766
767 L(tail_label);
768 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
769 L(end_label);
770 }
771 postamble();
772
773 if (jcp.with_eltwise) postops_injector_->prepare_table();
774}
775
776status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
777 const convolution_desc_t &cd, memory_desc_t &src_md,
778 memory_desc_t &weights_md, memory_desc_t &dst_md,
779 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
780 using namespace prop_kind;
781
782 if (!mayiuse(avx512_core)) return status::unimplemented;
783
784 const memory_desc_wrapper src_d(&src_md);
785 const memory_desc_wrapper weights_d(&weights_md);
786 const memory_desc_wrapper dst_d(&dst_md);
787 const memory_desc_wrapper bias_d(&bias_md);
788
789 if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
790 dst_d.data_type()))
791 return status::unimplemented;
792
793 const int regs = 28;
794 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
795 int ndims = src_d.ndims();
796
797 jcp = zero<decltype(jcp)>();
798 jcp.nthr = jcp.aligned_threads = nthreads;
799 jcp.ndims = ndims;
800 jcp.prop_kind = cd.prop_kind;
801 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
802 jcp.mb = src_d.dims()[0];
803 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
804 jcp.oc_without_padding = jcp.oc;
805 jcp.ic = src_d.dims()[1] / jcp.ngroups;
806 jcp.ic_without_padding = jcp.ic;
807 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
808 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
809 jcp.iw = src_d.dims()[ndims - 1];
810 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
811 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
812 jcp.ow = dst_d.dims()[ndims - 1];
813 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
814 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
815 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
816 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
817 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
818 jcp.l_pad = cd.padding[0][ndims - 3];
819 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
820 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
821 jcp.stride_w = cd.strides[ndims - 3];
822
823 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
824 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
825 jcp.dilate_w = cd.dilates[ndims - 3];
826
827 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
828 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
829 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
830 jcp.r_pad = calculate_end_padding(
831 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
832 jcp.b_pad = calculate_end_padding(
833 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
834 jcp.back_pad = calculate_end_padding(
835 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
836 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
837 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
838 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
839 if (kernel_outside_src) return status::unimplemented;
840
841 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
842 const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
843 const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
844 const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
845 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
846 auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
847 dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
848 auto curr_dst_tag = dst_d.matches_one_of_tag(
849 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
850 bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
851 src_d.format_kind() == format_kind::any)
852 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
853 dst_d.format_kind() == format_kind::any)
854 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
855
856 jcp.is_1stconv = is_1stconv(jcp);
857
858 bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
859 && src_d.data_type() == data_type::f32;
860
861 const int full_simd_w = cpu_isa_traits<avx512_core>::vlen / typesize;
862 jcp.simd_w = full_simd_w;
863 bool ok_to_try_lower_zmm = true
864 && IMPLICATION(is_data_layout_nxc,
865 jcp.oc < full_simd_w && jcp.ic < full_simd_w
866 && jcp.ngroups > 1)
867 && mayiuse(avx512_core) && src_d.data_type() == data_type::f32
868 && !jcp.is_1stconv && !ok_to_pad_channels
869 && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
870
871 if (ok_to_try_lower_zmm) {
872 for (auto simd : {8, 4}) {
873 if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
874 jcp.simd_w = simd;
875 break;
876 }
877 }
878 }
879
880 jcp.oc_block = jcp.simd_w;
881 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
882
883 if (ok_to_pad_channels) {
884 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
885 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
886 }
887 if (!IMPLICATION(!is_data_layout_nxc,
888 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
889 return status::unimplemented;
890
891 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
892 jcp.oc_tail = jcp.oc_without_padding % jcp.simd_w;
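    // ic_tail is nonzero only for the plain (nxc) layout, where channels are
    // not padded up to simd_w. oc_tail drives masked stores via
    // k_oc_tail_mask, while ic_tail makes the compute loops skip the unused
    // part of the last ic block.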
893
894 format_tag_t src_tag, dst_tag, wei_tag;
895
896 if (jcp.simd_w == 8) {
897 assert(with_groups);
898 src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
899 dst_tag = src_tag;
900 wei_tag = pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o);
901 } else if (jcp.simd_w == 4) {
902 assert(with_groups);
903 src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
904 dst_tag = src_tag;
905 wei_tag = pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o);
906 } else {
907 dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
908 src_tag = is_data_layout_nxc
909 ? dat_tag_nxc
910 : (jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c);
911 wei_tag = pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
912 OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
913 }
914
915 if (jcp.is_1stconv) {
916 wei_tag = with_groups
917 ? ((jcp.simd_w == 4)
918 ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
919 : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
920 : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
921 }
922
923 if (src_md.format_kind == format_kind::any)
924 CHECK(memory_desc_init_by_tag(src_md, src_tag));
925 else if (curr_src_tag != src_tag)
926 return status::unimplemented;
927 jcp.src_tag = src_tag;
928
929 if (dst_md.format_kind == format_kind::any)
930 CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
931 else if (curr_dst_tag != dst_tag)
932 return status::unimplemented;
933 jcp.dst_tag = dst_tag;
934
935 if (init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag)
936 != status::success)
937 return status::unimplemented;
938
939 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
940 if (jcp.with_bias) {
941 if (bias_d.format_kind() == format_kind::any)
942 CHECK(memory_desc_init_by_tag(bias_md, x));
943 }
944
945 CHECK(attr.set_default_formats(&dst_md));
946
947 const auto &post_ops = attr.post_ops_;
948 jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
949 const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
950 jcp.with_eltwise = eltwise_ind != -1;
951 if (jcp.with_eltwise) {
952 if (dst_d.data_type() == data_type::s32) return status::unimplemented;
953 }
954 const int binary_ind = post_ops.find(primitive_kind::binary);
955 jcp.with_binary = binary_ind != -1;
956
957 jcp.post_ops = post_ops;
958
959 using namespace injector;
960 static constexpr bool sum_at_pos_0_only = true;
961 static constexpr bool sum_requires_scale_one = true;
962 static constexpr bool sum_requires_zp_zero = true;
963 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
964 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
965 sum_requires_zp_zero});
966 if (!post_ops_ok_) return status::unimplemented;
967
968 jcp.typesize_in = typesize;
969 jcp.typesize_out = typesize;
970
971 if (jcp.is_1stconv) {
972 jcp.ur_w = nstl::min(jcp.ow, regs);
973 } else {
974 // avx512_core guard - just to avoid possible regression for other archs
975 if (mayiuse(avx512_core)) {
976 jcp.ur_w = nstl::min(jcp.ow, regs);
977 } else {
978 for (int ur_w = regs; ur_w > 0; --ur_w) {
979 if (jcp.ow % ur_w == 0) {
980 jcp.ur_w = ur_w;
981 break;
982 }
983 }
984 }
985 if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
986 jcp.ur_w = nstl::min(jcp.ow, regs);
987 }
988 }
989 // TODO (Tanya): currently applied to Segnet convolutions only.
990 // Need to try for other topologies
991 if (jcp.ow > 150 && jcp.ur_w < regs / 2) jcp.ur_w = regs;
992
993 int n_oi = (jcp.ow / jcp.ur_w);
994 int r_pad = calculate_end_padding(
995 jcp.l_pad, jcp.ur_w * n_oi, jcp.iw, jcp.stride_w, ext_kw);
996 if (jcp.l_pad > 0 && r_pad > 0) n_oi--;
997
998 // Heuristic to optimize code size on KNX
999 bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
1000 && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
1001 if (large_code_size) {
1002 const int max_code_size = 24 * 1024;
1003 const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
1004 int mult = 1;
1005 if (jcp.l_pad > 0) mult += 1;
1006 if (r_pad > 0) mult += 1;
1007 for (int ur_w = jcp.ur_w; ur_w > regs / 2; --ur_w) {
1008 if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
1009 jcp.ur_w = ur_w;
1010 break;
1011 }
1012 }
1013 }
1014
1015 /* Grouped channel offset to support 'non-blocked data' format for
1016 * convolution sizes with '(input_channel / ngroups) < simd' */
1017 jcp.nonblk_group_off
1018 = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw))
1019 ? jcp.ic
1020 : 1;
1021
1022 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
1023 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
1024 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1025
1026 auto is_ow_threading_applicable
1027 = [=]() { return (!jcp.is_1stconv && one_of(jcp.ndims, 3, 4)); };
1028
1029 jcp.ow_block = jcp.ow;
1030
1031 auto get_thr_eff = [=](int nb_oc_blocking, int ow_block, int nthr) {
1032 int nb_ow = div_up(jcp.ow, ow_block);
1033 int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
1034 int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
1035 float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
1036 float thr_eff
1037 = disbalance * (float)work_amount / rnd_up(work_amount, nthr);
1038 return thr_eff;
1039 };
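    // Illustrative numbers: ow = 17, ow_block = 10 gives
    // disbalance = 17 / 20 = 0.85; with work_amount = 30 and nthr = 8 the
    // balance term is 30 / 32 ~ 0.94, so thr_eff ~ 0.80.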
1040
1041 auto get_ow_block = [=](int nb_oc_blocking, int ur_w, int nthr) {
1042 int res_ow_block = jcp.ow;
1043 float eff = get_thr_eff(nb_oc_blocking, res_ow_block, nthr);
1044 if (!is_ow_threading_applicable()) return res_ow_block;
1045
1046 int L2_part = (platform::get_per_core_cache_size(2) * 7 / 8) / typesize;
1047 int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
1048 int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
1049 int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
1050 * jcp.kw * jcp.kh;
1051 int nurw_cache = (L2_part - 2 * size_wei_chunk)
1052 / (2 * size_dst_chunk + 2 * size_src_chunk);
1053 // current design of generate() requires ow_block >= 2 * ur_w
1054 int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
1055
1056 int ow_block_thr = ow_block_cache;
1057 eff = get_thr_eff(nb_oc_blocking, ow_block_thr, nthr);
1058
1059 int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
1060 int start_nb_ow = div_up(jcp.ow, ow_block_thr);
1061 for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
1062 int ow_block
1063 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
1064 float eff_threshold = 0.9f;
1065 if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
1066 break;
1067 if (div_up(jcp.ow, ow_block) != nb_ow) continue;
1068 float thr_eff = get_thr_eff(nb_oc_blocking, ow_block, nthr);
1069 float eff_step = 1.f;
1070 if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
1071 ow_block_thr = ow_block;
1072 eff = thr_eff;
1073 }
1074 eff_threshold = 0.98f;
1075 if (eff > eff_threshold) break;
1076 }
1077 res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
1078 eff = get_thr_eff(nb_oc_blocking, res_ow_block, nthr);
1079 return res_ow_block;
1080 };
1081
1082 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1083 if (mayiuse(avx512_core)) {
1084 int try_nb_oc_blocking = 2;
1085 unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
1086 * jcp.ic_block * jcp.kh * jcp.kd;
1087 unsigned int ker_out_size
1088 = typesize * jcp.ow * jcp.oc_block * try_nb_oc_blocking;
1089 unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
1090 * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
1091 unsigned int ker_total_size
1092 = ker_inp_size + ker_out_size + ker_wei_size;
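        // ker_total_size is a rough per-call working-set estimate: the input
        // elements feeding one output row (for one ic block), the output
        // rows for try_nb_oc_blocking oc blocks, and the matching weights.
        // It is compared against L1 below when deciding whether the
        // embedded-broadcast kernel is worthwhile.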
1093
1094 bool embd_bcast_condition_base = true
1095 && (jcp.kw == 3 && jcp.ow <= 28
1096 && ker_total_size < L1_cache_size)
1097 && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
1098 && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
1099 // These conditions define a set of shapes with 'ow = 1' which
1100 // have a very limited optimization space for performance. Try
1101 // to optimize by using a larger 'nb_oc_blocking' size.
1102 bool expl_bcast_condition
1103 = everyone_is(1, jcp.ngroups, jcp.mb, jcp.stride_h, jcp.ow,
1104 jcp.stride_w, jcp.id, jcp.od, jcp.kd, jcp.stride_d)
1105 && jcp.iw == jcp.kw && jcp.nb_oc > 1
1106 && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.dilate_w, jcp.f_pad,
1107 jcp.back_pad, jcp.dilate_d)
1108 && jcp.oh >= 60 && jcp.kh >= 3;
1109
1110 bool embd_bcast_condition = !expl_bcast_condition
1111 && (jcp.kw > 3
1112 || (jcp.stride_w == 1 && jcp.stride_h == 1
1113 && embd_bcast_condition_base)
1114 || ((jcp.stride_w != 1 || jcp.stride_h != 1)
1115 && ((jcp.mb <= 16
1116 && (jcp.oc <= 192 || jcp.oh <= 10)
1117 && embd_bcast_condition_base)))
1118 || (jcp.mb == 1
1119 && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
1120 || (jcp.ow <= 147 && jcp.oc <= 96))));
1121
1122 if (jcp.mb == 1) {
1123 unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
1124 * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
1125 unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
1126
1127            // Estimate whether we need to limit the number of threads
1128            // and calculate this number. This includes some heuristics.
1129 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1130 int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
1131 int job_size_min = work_amount / nthreads;
1132 int job_size_max = div_up(work_amount, nthreads);
1133 int ch_max = rnd_up(jcp.oh, job_size_max);
1134 int ch_min = (job_size_min == 0) ? jcp.oh
1135 : rnd_up(jcp.oh, job_size_min);
1136 bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
1137 && (jcp.oh != 8 || ch_max / jcp.oh > 1);
1138 bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
1139 && (jcp.oh != 8 || ch_min / jcp.oh > 1);
1140 bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
1141 || nthreads > oc_chunks;
1142 if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
1143 && wei_size / inp_size > 24
1144 && (not_aligned_max || not_aligned_min) && eligible_case) {
1145                // Try to find a thread count > nthreads / 2 such that
1146                // oc_chunks is a multiple of it, or it is a multiple of
1147                // oc_chunks. Otherwise, keep the default value.
1148 // TODO: implement a task-based alternative without throttling.
1149 jcp.aligned_threads = jcp.nthr;
1150 for (int i = jcp.nthr; i > jcp.nthr / 2; i--) {
1151 if (oc_chunks % i == 0 || i % oc_chunks == 0) {
1152 jcp.aligned_threads = i;
1153 break;
1154 }
1155 }
1156 }
1157 }
1158
1159 const int max_nb_oc = 5;
1160 if (embd_bcast_condition) {
1161 jcp.kernel_kind = embd_bcast;
1162 jcp.ur_w = nstl::min(jcp.ow, regs);
1163 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1164 const unsigned int L1_cache_size
1165 = platform::get_per_core_cache_size(1);
1166 if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
1167 && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
1168 && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
1169 && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
1170 jcp.nb_oc_blocking = try_nb_oc_blocking;
1171 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1172 }
1173 } else {
1174 jcp.kernel_kind = expl_bcast;
1175 jcp.nb_ic_blocking = 1;
1176 if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)
1177 || expl_bcast_condition) {
1178 float best_thr_eff = 0.f;
1179 int best_nb_oc_blocking = 1;
1180 for (int i = nstl::min(jcp.nb_oc, max_nb_oc); i > 0; i--) {
1181 if (jcp.nb_oc % i == 0) {
1182 if (expl_bcast_condition) {
1183 best_nb_oc_blocking = i;
1184 break;
1185 } else {
1186 int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
1187 int ow_block = get_ow_block(i, ur_w, jcp.nthr);
1188 float thr_eff = get_thr_eff(i, ow_block, jcp.nthr);
1189 if (thr_eff > 1.05f * best_thr_eff) {
1190 best_nb_oc_blocking = i;
1191 best_thr_eff = thr_eff;
1192 }
1193 }
1194 }
1195 }
1196 jcp.nb_oc_blocking = best_nb_oc_blocking;
1197 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1198 }
1199 }
1200 }
1201
1202 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1203
1204 bool args_ok = true && jcp.l_pad <= jcp.ur_w
1205 && jcp.ic <= src_d.padded_dims()[1]
1206 && jcp.oc <= dst_d.padded_dims()[1]
1207 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1208 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1209 if (!args_ok) return status::unimplemented;
1210
1211 int r_pad_no_tail = nstl::max(0,
1212 calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
1213 jcp.stride_w, ext_kw));
1214 if (r_pad_no_tail > jcp.ur_w) return status::unimplemented;
1215
1216 pick_loop_order(jcp);
1217
1218 jcp.nb_ic_L2 = jcp.nb_ic;
1219
1220 jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, jcp.nthr);
1221 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1222 float thr_eff = get_thr_eff(jcp.nb_oc_blocking, jcp.ow_block, jcp.nthr);
1223
1224    /* Adjust the thread decomposition
1225     * to improve thr_eff for small-size problems;
1226     * the L1_cache_size threshold is empirical. */
1227 size_t wei_size
1228 = (size_t)typesize * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd;
1229 size_t out_size = (size_t)jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh
1230 * jcp.ow * jcp.od;
1231 size_t inp_size = (size_t)jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih
1232 * jcp.iw * jcp.id;
1233 size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size);
1234 float eff_threshold = 0.98f;
1235
1236 if (thr_eff < eff_threshold && jcp.ngroups < jcp.nthr
1237 && (total_size < L1_cache_size)) {
1238 int ow_block = jcp.ow_block;
1239 float best_thr_eff = -1.0f;
1240 float eff = -1.0f;
1241 int end_nthr = with_groups ? jcp.ngroups : 1;
1242 for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) {
1243 ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, nthr);
1244 eff = get_thr_eff(jcp.nb_oc_blocking, ow_block, nthr);
1245 if (eff > 1.1f * best_thr_eff) {
1246 best_thr_eff = eff;
1247 jcp.ow_block = ow_block;
1248 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1249 jcp.nthr = jcp.aligned_threads = nthr;
1250 if (best_thr_eff > eff_threshold) break;
1251 }
1252 }
1253 }
1254
1255 const int L2_size = platform::get_per_core_cache_size(2) / typesize;
1256    // Source and output data need to fit in L2,
1257    // leaving some space for weights and prefetching.
1258 int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
1259 - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
1260 / (jcp.stride_h * jcp.iw + jcp.ow));
1261 jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
1262
1263 if (is_data_layout_nxc) {
1264 // TODO: improve L2 blocking for large IC
1265        const int nb_ic_threshold_L2 = 32;
1266        if (jcp.nb_ic > nb_ic_threshold_L2 && jcp.nb_ic < 2 * nb_ic_threshold_L2)
1267            jcp.nb_ic_L2 = div_up(jcp.nb_ic, 2);
1268        else
1269            jcp.nb_ic_L2 = nstl::min(nb_ic_threshold_L2, jcp.nb_ic);
1270 }
1271
1272 // A rough check on code size
1273 // TODO: come up with a tighter bound
1274 {
1275 const int max_code_size = 256 * 1024; // default size of jit generator
1276 int mult = 1 + (jcp.l_pad > 0) + (r_pad > 0);
1277 const float max_instruction_size = 15;
1278 float ur_fac
1279 = (float)jcp.kw * jcp.ic_block * jcp.nb_oc_blocking * jcp.ur_w;
1280 float code_size = mult * ur_fac * max_instruction_size;
1281 if (code_size > max_code_size) return status::unimplemented;
1282 }
1283
1284 return status::success;
1285}
1286
1287void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
1288 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1289 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
1290 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_out);
1291}
1292
1293template <typename Vmm>
1294void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::prepare_output(
1295 int ur_w) {
1296 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1297 for (int j = 0; j < ur_w; j++) {
1298 Vmm vmm = vmm_out(j, k);
1299 vpxord(vmm, vmm, vmm);
1300 }
1301 }
1302}
1303
1304template <typename Vmm>
1305void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::store_output(int ur_w) {
1306 Label no_update_label;
1307 const int ic_tail = jcp.ic_without_padding % jcp.simd_w;
1308 const bool dsrc_layout_nxc = is_dsrc_layout_nxc();
1309 mov(reg_channel, ptr[param + GET_OFF(channel)]);
1310 cmp(reg_channel, 0);
1311 je(no_update_label, T_NEAR);
1312 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1313 for (int j = 0; j < ur_w; j++) {
1314 Vmm vmm = vmm_out(j, k);
1315 size_t aux_src_offset = get_diff_src_offset(j, k);
1316 vaddps(vmm,
1317 EVEX_compress_addr_safe(
1318 reg_src, aux_src_offset, reg_long_offt));
1319 }
1320 }
1321
1322 L(no_update_label);
1323 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1324 for (int j = 0; j < ur_w; j++) {
1325 Vmm vmm = vmm_out(j, k);
1326            // mask only needed for the last ic_block
1327 if (ic_tail && k + 1 == jcp.nb_ic_blocking && dsrc_layout_nxc)
1328 vmm = vmm | k_ic_tail_mask;
1329 size_t aux_src_offset = get_diff_src_offset(j, k);
1330 vmovups(EVEX_compress_addr_safe(
1331 reg_src, aux_src_offset, reg_long_offt),
1332 vmm);
1333 }
1334 }
1335}
1336
1337template <typename Vmm>
1338void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::compute_loop_fma(
1339 int ur_w, int l_overflow, int r_overflow) {
1340 Label kh_label, kd_label;
1341 int kw = jcp.kw;
1342 int ow = jcp.ow;
1343
1344 int ic_block = jcp.ic_block;
1345 int oc_block = jcp.oc_block;
1346 int stride_w = jcp.stride_w;
1347 int stride_h = jcp.stride_h;
1348
1349 int ker_pipeline_depth = 4;
1350 assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
1351 assert(oc_block >= ker_pipeline_depth);
1352
1353 int num_ker_loads = oc_block * kw;
1354 const bool ddst_layout_nxc = is_ddst_layout_nxc();
1355 int oc_mult = ddst_layout_nxc ? jcp.ngroups * jcp.oc : oc_block;
1356 const bool ocb_loop_in_compute_function = ddst_layout_nxc;
1357
1358 const int ic_tail = jcp.ic_tail;
1359 const int oc_tail = jcp.oc_tail;
1360 std::vector<Label> oc_tail_jmp(kw);
1361 if (ic_tail || oc_tail) ker_pipeline_depth = 1;
1362
1363 if (one_of(jcp.ndims, 3, 4)) {
1364 mov(aux_reg_dst, reg_dst);
1365 mov(aux_reg_ker, reg_ker);
1366 }
1367
1368 if (jcp.ndims == 5) {
1369 push(reg_src);
1370
1371 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1372 mov(aux_reg_dst_d, reg_dst);
1373 if (ocb_loop_in_compute_function) {
1374 // need to continue with the same kernel pointer, but as
1375 // aux_reg_ker_d == reg_ker we need to save its value and restore
1376 // it after kd loop
1377 assert(aux_reg_ker_d == reg_ker);
1378 push(aux_reg_ker_d);
1379 } else
1380 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1381
1382 L(kd_label);
1383 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1384 } else {
1385 mov(reg_kj, reg_kh);
1386 }
1387
1388 if (jcp.ndims == 5) {
1389 mov(aux_reg_dst, aux_reg_dst_d);
1390 mov(aux_reg_ker, aux_reg_ker_d);
1391 }
1392
1393 L(kh_label);
1394 {
1395 int step = 0;
1396 for (int ki = 0; ki < kw; ki++) {
1397 for (int oc = 0; oc < oc_block; oc++) {
1398 if (oc_tail && oc >= oc_tail) {
1399                    // if diff_dst has only tails to compute, skip early
1400 if (jcp.oc == oc_tail)
1401 break;
1402 else if (oc == oc_tail) {
1403 cmp(reg_channel, oc_tail);
1404 je(oc_tail_jmp[ki], T_NEAR);
1405 }
1406 }
1407 if (step == 0) {
1408 for (int i = 0; i < ker_pipeline_depth; i++) {
1409 int aux_kernel_offset = typesize
1410 * ((oc + i) * oc_block
1411 + ki * ic_block * oc_block);
1412 vmovups(vmm_ker(i),
1413 EVEX_compress_addr(
1414 aux_reg_ker, aux_kernel_offset));
1415 }
1416 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
1417 int load_offset = ker_pipeline_depth - 1;
1418 int ker_load_reg_idx
1419 = (step + load_offset) % ker_pipeline_depth;
1420 int aux_kernel_offset = typesize
1421 * ((oc + load_offset) * oc_block
1422 + ki * ic_block * oc_block);
1423 vmovups(vmm_ker(ker_load_reg_idx),
1424 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1425 }
1426
1427 auto vmm_kernel = vmm_ker(step % ker_pipeline_depth);
1428
1429 int jj_start = get_iw_start(ki, l_overflow);
1430 int jj_end = get_iw_end(ur_w, ki, r_overflow);
1431 const int dil_w = jcp.dilate_w + 1;
1432 const int ref_jj_start
1433 = nstl::max(0, l_overflow - (kw - 1 - ki) * dil_w);
1434 const int ref_jj_end
1435 = ur_w - nstl::max(0, r_overflow - ki * dil_w);
1436 assert(IMPLICATION(stride_w == 1,
1437 jj_start == ref_jj_start && jj_end == ref_jj_end));
1438 UNUSED(dil_w);
1439 UNUSED(ref_jj_start);
1440 UNUSED(ref_jj_end);
1441
1442 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1443 assert((jj + jcp.l_pad - ki * (jcp.dilate_w + 1)) % stride_w
1444 == 0);
1445 int aux_dst_offset = get_dst_offset(jj, oc, ki);
1446 vfmadd231ps(vmm_out(jj, 0), vmm_kernel,
1447 EVEX_compress_addr(
1448 aux_reg_dst, aux_dst_offset, true));
1449 }
1450 step++;
1451 }
1452 L(oc_tail_jmp[ki]);
1453 }
1454
1455 const int ker_shift = typesize * stride_h * kw * oc_block * ic_block;
1456 add(aux_reg_ker, ker_shift);
1457 const int ddst_shift = typesize * (jcp.dilate_h + 1) * ow * oc_mult;
1458 sub(aux_reg_dst, ddst_shift);
1459
1460 dec(reg_kj);
1461 cmp(reg_kj, 0);
1462 jg(kh_label, T_NEAR);
1463 }
1464 if (jcp.ndims == 5) {
1465 const int depth_ddst_shift
1466 = typesize * (jcp.dilate_d + 1) * jcp.oh * ow * oc_mult;
1467 sub(aux_reg_dst_d, depth_ddst_shift);
1468 const int depth_ker_shift = typesize * jcp.stride_d * jcp.kw * jcp.kh
1469 * oc_block * ic_block;
1470 add(aux_reg_ker_d, depth_ker_shift);
1471
1472 dec(reg_ki);
1473 cmp(reg_ki, 0);
1474 jg(kd_label, T_NEAR);
1475 if (ocb_loop_in_compute_function) pop(aux_reg_ker_d);
1476 }
1477
1478 if (jcp.ndims == 5) { pop(reg_src); }
1479}
1480
1481template <typename Vmm>
1482void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::compute_loop_fma_core(
1483 int ur_w, int l_overflow, int r_overflow, int k_offset) {
1484 int kw = jcp.kw;
1485 int ow = jcp.ow;
1486 int stride_w = jcp.stride_w;
1487 int ic_block = jcp.ic_block;
1488 int oc_block = jcp.oc_block;
1489 int nb_ic_block = jcp.nb_ic_blocking;
1490 Label kh_label, kd_label;
1491
1492 const bool ddst_layout_nxc = is_ddst_layout_nxc();
1493 int shift_ker_ptr = typesize * kw * oc_block * ic_block;
1494 int oc_mult = ddst_layout_nxc ? jcp.ngroups * jcp.oc : oc_block;
1495 int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_mult;
1496
1497 const int oc_tail = jcp.oc_tail;
1498 const int max_filter_size = 20;
1499 Label oc_tail_jmp[max_filter_size];
1500
1501 auto kernel_offset = [=](int icb, int oc, int ki) {
1502 int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
1503 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
1504 int oc_offset = oc * jcp.oc_block;
1505 return typesize * (blk_offset + oc_offset);
1506 };
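    // Illustrative offset with hypothetical sizes (kh = kw = 3, kd = 1,
    // oc_block = ic_block = 16, typesize = 4): icb = 1, oc = 2, ki = 0 gives
    // blk_idx = 9, blk_offset = 9 * 256 = 2304, oc_offset = 32, i.e.
    // 4 * (2304 + 32) = 9344 bytes.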
1507
1508 if (one_of(jcp.ndims, 3, 4)) {
1509 mov(aux_reg_dst, reg_dst);
1510 mov(aux_reg_ker, reg_ker);
1511 }
1512
1513 const bool ocb_loop_in_compute_function = ddst_layout_nxc;
1514 if (jcp.ndims == 5) {
1515 push(reg_src);
1516
1517 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1518 mov(aux_reg_dst_d, reg_dst);
1519 if (ocb_loop_in_compute_function) {
1520 // need to continue with the same kernel pointer, but as
1521 // aux_reg_ker_d == reg_ker we need to save its value and restore
1522 // it after kd loop
1523 assert(aux_reg_ker_d == reg_ker);
1524 push(aux_reg_ker_d);
1525 } else
1526 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1527
1528 L(kd_label);
1529 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1530 } else {
1531 mov(reg_kj, reg_kh);
1532 }
1533
1534 if (jcp.ndims == 5) {
1535 mov(aux_reg_dst, aux_reg_dst_d);
1536 mov(aux_reg_ker, aux_reg_ker_d);
1537 }
1538
1539 L(kh_label);
1540 {
1541 for (int ki = 0; ki < kw; ki++) {
1542 int jj_start = get_iw_start(ki, l_overflow);
1543 int jj_end = get_iw_end(ur_w, ki, r_overflow);
1544 for (int oc = 0; oc < oc_block; oc++) {
1545 if (oc_tail && oc >= oc_tail) {
1546                    // if diff_dst has only tails to compute, skip early
1547 if (jcp.oc == oc_tail)
1548 break;
1549 else if (oc == oc_tail) {
1550 cmp(reg_channel, oc_tail);
1551 je(oc_tail_jmp[ki], T_NEAR);
1552 }
1553 }
1554 if (jcp.kernel_kind == expl_bcast) {
1555 for (int jj = jj_start; jj < jj_end; jj++) {
1556 int aux_output_offset = get_dst_offset(jj, oc, ki);
1557 vbroadcastss(vmm_inp(jj, nb_ic_block),
1558 ptr[aux_reg_dst + aux_output_offset]);
1559 }
1560 }
1561 for (int ii = 0; ii < nb_ic_block; ii++) {
1562 int aux_kernel_offset
1563 = kernel_offset(ii, oc, ki + k_offset);
1564 if (jj_end - jj_start > 0)
1565 vmovups(vmm_wei,
1566 EVEX_compress_addr(
1567 aux_reg_ker, aux_kernel_offset));
1568 for (int jj = jj_start; jj < jj_end; jj += stride_w)
1569 if (jcp.kernel_kind == expl_bcast)
1570 vfmadd231ps(vmm_out(jj, ii),
1571 vmm_inp(jj, nb_ic_block), vmm_wei);
1572 else
1573 vfmadd231ps(vmm_out(jj, ii), vmm_wei,
1574 EVEX_compress_addr(aux_reg_dst,
1575 get_dst_offset(jj, oc, ki), true));
1576 }
1577 }
1578 L(oc_tail_jmp[ki]);
1579 }
1580 add(aux_reg_ker, shift_ker_ptr);
1581 sub(aux_reg_dst, shift_dst_ptr);
1582 dec(reg_kj);
1583 cmp(reg_kj, 0);
1584 jg(kh_label, T_NEAR);
1585 }
1586
1587 if (jcp.ndims == 5) {
1588 sub(aux_reg_dst_d,
1589 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * oc_mult);
1590 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
1591
1592 dec(reg_ki);
1593 cmp(reg_ki, 0);
1594 jg(kd_label, T_NEAR);
1595
1596 if (ocb_loop_in_compute_function) pop(aux_reg_ker_d);
1597 pop(reg_src);
1598 }
1599}
1600
1601template <typename Vmm>
1602inline void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::compute_loop(
1603 int ur_w, int l_overflow, int r_overflow, int k_offset) {
1604 if (jcp.ndims == 5) push(reg_oi);
1605
1606 prepare_output(ur_w);
1607
1608 Label skip_compute_loop;
1609 if (jcp.ndims == 5) {
1610 mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
1611 cmp(reg_kj, 0);
1612 jle(skip_compute_loop, T_NEAR);
1613 }
1614 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1615 cmp(reg_kj, 0);
1616 jle(skip_compute_loop, T_NEAR);
1617
1618 const bool generate_ocb_loop = jcp.nb_oc > 1 && is_ddst_layout_nxc();
1619 Label oc_loop;
1620 if (generate_ocb_loop) {
1621 push(reg_dst);
1622 push(reg_ker);
1623
1624 mov(reg_channel, ptr[param1 + GET_OFF(reduce_work)]);
1625 L(oc_loop);
1626 }
1627
1628 if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
1629 compute_loop_fma(ur_w, l_overflow, r_overflow);
1630 else
1631 compute_loop_fma_core(ur_w, l_overflow, r_overflow, k_offset);
1632
1633 if (generate_ocb_loop) {
1634 add(reg_dst, jcp.oc_block * typesize);
1635 const int ker_shift = jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw
1636 * jcp.ic_block * jcp.oc_block * typesize;
1637 add(reg_ker, ker_shift);
1638 sub(reg_channel, jcp.oc_block);
1639 jg(oc_loop, T_NEAR);
1640
1641 pop(reg_ker);
1642 pop(reg_dst);
1643 }
1644
1645 L(skip_compute_loop);
1646 store_output(ur_w);
1647 if (jcp.ndims == 5) pop(reg_oi);
1648}
1649
1650template <typename Vmm>
1651void _jit_avx512_common_conv_bwd_data_kernel_f32<Vmm>::generate() {
1652 int iw = jcp.iw;
1653 int kw = jcp.kw;
1654 int ur_w = jcp.ur_w;
1655 int ic_block = jcp.ic_block;
1656 int oc_block = jcp.oc_block;
1657 int nb_iw = jcp.nb_iw;
1658 int iw_block = jcp.iw_block;
1659 int ur_w_tail = jcp.ur_w_tail;
1660 int dilate_w = jcp.dilate_w + 1;
1661 int stride_w = jcp.stride_w;
1662
1663 int dst_shift = jcp.typesize_in * (ur_w / stride_w)
1664 * (is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : oc_block);
1665 int src_shift = jcp.typesize_out * ur_w
1666 * (is_dsrc_layout_nxc() ? jcp.ngroups * jcp.ic : ic_block);
1667
1668 preamble();
1669
1670 mov(reg_src, ptr[param + GET_OFF(src)]);
1671 mov(reg_dst, ptr[param + GET_OFF(dst)]);
1672 mov(reg_ker, ptr[param + GET_OFF(filt)]);
1673
1674 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
1675
1676 const int ic_tail = jcp.ic_tail;
1677 if (ic_tail) {
1678 Label masking_done;
1679 // dummy mask all 1's
1680 kxnorw(k_ic_tail_mask, k_ic_tail_mask, k_ic_tail_mask);
1681 mov(reg_load_work, ptr[param1 + GET_OFF(load_work)]);
1682 cmp(reg_load_work, jcp.nb_ic_blocking * jcp.ic_block);
1683 je(masking_done, T_NEAR);
1684 Reg32 reg_tail_32 = reg_tail.cvt32();
1685 mov(reg_tail_32, (1 << ic_tail) - 1);
1686 kmovw(k_ic_tail_mask, reg_tail_32);
1687 L(masking_done);
1688 }
1689
1690 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
1691 int r_overflow = nstl::max(
1692 0, ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad)) / stride_w);
1693 int r_overflow_no_tail = nstl::max(0,
1694 ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad + ur_w_tail))
1695 / stride_w);
1696
1697 int body_l_overflow = 0, body_r_overflow = 0;
1698 int n_oi = iw / ur_w;
1699 int head_n_oi = 0, body_n_oi = 0, pretail_n_oi = 0, tail_n_oi = 0;
1700 int head_thread = 0, pretail_thread = 0, tail_thread = 0;
1701 bool threaded = is_iw_threading_on(jcp);
1702 Label head_label, body_label, pretail_label, tail_label, end_label;
1703 assert(n_oi > 0);
1704 if (r_overflow_no_tail > 0) n_oi--;
1705 if (l_overflow > 0) n_oi--;
1706 if (n_oi < 0) {
        // l_overflow and r_overflow_no_tail are handled in the same
        // compute_loop: perform a single body iteration that covers both.
        // TODO: Align other convolution kernels with this kernel. This
        // version uses r_overflow_no_tail instead of r_overflow in the
        // compute loop because, when iw == ur_w, ur_w_tail == 0 and
        // r_overflow_no_tail is the more appropriate value.
1713 body_l_overflow = l_overflow;
1714 body_r_overflow = r_overflow_no_tail;
1715 n_oi = 1;
1716 l_overflow = 0;
1717 r_overflow_no_tail = 0;
1718 }
1719
1720 if (!threaded) {
1721 if (n_oi > 1) { mov(reg_oi, n_oi); }
1722 } else {
1723 // Setup for threaded code generation, and jump into the correct
1724 // portion of code for execution.
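        // iw is split into nb_iw blocks: the head block absorbs l_overflow,
        // the pretail block absorbs r_overflow_no_tail, and the tail block
        // computes ur_w_tail. reg_iwb selects which portion this call
        // executes, and reg_oi carries the iteration count for that portion.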
1725 head_thread = 0;
1726 tail_thread = nb_iw - 1;
1727 pretail_thread = tail_thread;
1728
1729 int base_n_oi = iw_block / ur_w;
1730 head_n_oi = l_overflow > 0 ? base_n_oi - 1 : base_n_oi;
1731 tail_n_oi = (iw - iw_block * (nb_iw - 1)) / ur_w;
1732 pretail_n_oi = tail_n_oi;
1733 if (r_overflow_no_tail > 0) {
1734 if (tail_n_oi > 0) {
1735 pretail_n_oi--;
1736 tail_n_oi = pretail_n_oi;
1737 } else {
1738 // pretail_thread and tail_thread are different
1739 pretail_n_oi = base_n_oi - 1;
1740 pretail_thread = tail_thread - 1;
1741 }
1742 if (head_thread == pretail_thread) {
1743 head_n_oi--;
1744 pretail_n_oi = 0;
1745 tail_n_oi = 0;
1746 }
1747 }
1748 body_n_oi = (head_thread < pretail_thread - 1) ? base_n_oi : 0;
1749
        // n_oi determines how much control flow needs to be generated for
        // the body portion of the code. As such, n_oi must be set to the
        // maximum number of iterations used in the body code section.
1753 n_oi = nstl::max(body_n_oi, head_n_oi);
1754 n_oi = nstl::max(n_oi, pretail_n_oi);
1755
1756 assert(iw_block % ur_w == 0);
1757 mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
1758
1759 if (head_n_oi != 0) mov(reg_oi, head_n_oi);
1760 cmp(reg_iwb, head_thread);
1761 je(head_label, T_NEAR);
1762
1763 cmp(reg_iwb, pretail_thread);
1764 if (pretail_n_oi == 0) {
1765 je(pretail_label, T_NEAR);
1766 } else {
1767 mov(reg_oi, pretail_n_oi);
1768 je(body_label, T_NEAR);
1769 }
1770 if (pretail_thread != tail_thread) {
1771 cmp(reg_iwb, tail_thread);
1772 je(tail_label, T_NEAR);
1773 }
1774 if (body_n_oi != 0) {
1775 mov(reg_oi, body_n_oi);
1776 jmp(body_label, T_NEAR);
1777 } else {
1778 jmp(end_label, T_NEAR);
1779 }
1780 }
1781 L(head_label);
1782 if (l_overflow > 0) {
1783 compute_loop(ur_w, l_overflow, 0);
1784 if (threaded && head_n_oi == 0 && head_thread != pretail_thread)
1785 jmp(end_label, T_NEAR);
1786 else {
1787 add(reg_src, src_shift);
1788 add(reg_dst, dst_shift);
1789 }
1790 }
1791 L(body_label);
1792 if (n_oi > 0) {
1793 Label ow_loop_label;
1794 L(ow_loop_label);
1795 {
1796 compute_loop(ur_w, body_l_overflow, body_r_overflow);
1797 if (n_oi > 1 || r_overflow_no_tail > 0 || ur_w_tail != 0) {
1798 add(reg_src, src_shift);
1799 if (!jcp.large_w_filter) { add(reg_dst, dst_shift); }
1800 }
1801 if (n_oi > 1) {
1802 sub(reg_oi, 1);
1803 jg(ow_loop_label, T_NEAR);
1804 }
1805 }
1806 }
1807 if (threaded) {
1808 mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
1809 cmp(reg_iwb, pretail_thread);
1810 jne(end_label, T_NEAR);
1811 }
1812 L(pretail_label);
1813 if (r_overflow_no_tail > 0) {
1814 compute_loop(ur_w, 0, r_overflow_no_tail);
1815 if (ur_w_tail != 0) {
1816 if (threaded && tail_thread != pretail_thread)
1817 jmp(end_label, T_NEAR);
1818 add(reg_src, src_shift);
1819 add(reg_dst, dst_shift);
1820 }
1821 }
1822 L(tail_label);
1823 if (ur_w_tail != 0) {
1824 /* if 'filter-width > ur_w' then the main loop only partially computes
1825 * width, ur_w_tail needs to offset the initial ur_w from the filter
1826 * address. */
1827 if (jcp.large_w_filter)
1828 compute_loop(ur_w_tail, body_l_overflow, r_overflow - ur_w, ur_w);
1829 else
1830 compute_loop(ur_w_tail, 0, r_overflow);
1831 }
1832 L(end_label);
1833
1834 postamble();
1835}
1836
1837status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
1838 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1839 memory_desc_t &diff_src_md, memory_desc_t &weights_md,
1840 memory_desc_t &diff_dst_md, int nthreads) {
1841 if (!mayiuse(avx512_core)) return status::unimplemented;
1842
1843 const memory_desc_wrapper diff_src_d(&diff_src_md);
1844 const memory_desc_wrapper weights_d(&weights_md);
1845 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
1846 jcp = zero<decltype(jcp)>();
1847
1848 if (!everyone_is(data_type::f32, diff_dst_d.data_type(),
1849 weights_d.data_type(), diff_src_d.data_type()))
1850 return status::unimplemented;
1851
1852 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
1853 int ndims = diff_src_d.ndims();
1854
1855 jcp.nthr = jcp.aligned_threads = nthreads;
1856 jcp.ndims = ndims;
1857 jcp.prop_kind = cd.prop_kind;
1858
1859 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1860 jcp.mb = diff_src_d.dims()[0];
1861
1862 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1863 jcp.oc_without_padding = jcp.oc;
1864 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
1865 jcp.ic_without_padding = jcp.ic;
1866
1867 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
1868 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
1869 jcp.iw = diff_src_d.dims()[ndims - 1];
1870 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1871 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
1872 jcp.ow = diff_dst_d.dims()[ndims - 1];
1873
1874 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1875 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1876 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1877
1878 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1879 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1880 jcp.l_pad = cd.padding[0][ndims - 3];
1881
1882 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1883 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1884 jcp.stride_w = cd.strides[ndims - 3];
1885
1886 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1887 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
1888 jcp.dilate_w = cd.dilates[ndims - 3];
1889 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
1890 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
1891 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
1892 return status::unimplemented;
1893
1894 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1895 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1896 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1897 jcp.r_pad = calculate_end_padding(
1898 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1899 jcp.b_pad = calculate_end_padding(
1900 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1901 jcp.back_pad = calculate_end_padding(
1902 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1903 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
1904 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
1905 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
1906 if (kernel_outside_src) return status::unimplemented;
1907
1908 jcp.aligned_threads = 0;
1909 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
1910 const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
1911 const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
1912 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1913 auto curr_src_tag = diff_src_d.matches_one_of_tag(
1914 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1915 auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
1916 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1917 bool is_data_layout_nxc
1918 = IMPLICATION(curr_src_tag != dat_tag_nxc,
1919 diff_src_d.format_kind() == format_kind::any)
1920 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
1921 diff_dst_d.format_kind() == format_kind::any)
1922 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
1923
1924 jcp.is_1stconv = false;
1925
1926 bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
1927 && diff_src_d.data_type() == data_type::f32;
1928
1929 const int full_simd_w = cpu_isa_traits<avx512_core>::vlen / typesize;
1930 jcp.simd_w = full_simd_w;
1931 bool ok_to_try_lower_zmm = true
1932 && IMPLICATION(is_data_layout_nxc,
1933 jcp.ic < full_simd_w && jcp.oc < full_simd_w
1934 && jcp.ngroups > 1)
1935 && mayiuse(avx512_core) && diff_src_d.data_type() == data_type::f32
1936 && !jcp.is_1stconv
1937 && (jcp.oc % jcp.simd_w != 0 || jcp.ic % jcp.simd_w != 0)
1938 && !ok_to_pad_channels;
1939
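
    // If ic or oc is not a multiple of the full zmm width and the channels
    // cannot be padded, fall back to a narrower simd width (8 or 4) that
    // evenly divides both ic and oc.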
1940 if (ok_to_try_lower_zmm) {
1941 for (auto simd : {8, 4}) {
1942 if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
1943 jcp.simd_w = simd;
1944 break;
1945 }
1946 }
1947 }
1948
1949 jcp.oc_block = jcp.simd_w;
1950 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
1951
1952 if (ok_to_pad_channels) {
1953 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1954 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1955 }
1956
1957 if (!IMPLICATION(!is_data_layout_nxc,
1958 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
1959 return status::unimplemented;
1960 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1961 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0;
1962
1963 format_tag_t dat_tag, wei_tag;
1964 const auto nxc_tag = pick(ndims - 3, nwc, nhwc, ndhwc);
1965
1966 if (jcp.simd_w == 8) {
1967 assert(with_groups);
1968 dat_tag = is_data_layout_nxc ? nxc_tag
1969 : pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
1970 wei_tag = pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i);
1971 } else if (jcp.simd_w == 4) {
1972 assert(with_groups);
1973 dat_tag = is_data_layout_nxc ? nxc_tag
1974 : pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
1975 wei_tag = pick(ndims - 3, gOIw4o4i, gOIhw4o4i, gOIdhw4o4i);
1976 } else {
1977 dat_tag = is_data_layout_nxc
1978 ? pick(ndims - 3, nwc, nhwc, ndhwc)
1979 : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1980 wei_tag = pick(2 * ndims - 6 + with_groups, OIw16o16i, gOIw16o16i,
1981 OIhw16o16i, gOIhw16o16i, OIdhw16o16i, gOIdhw16o16i);
1982 }
1983
1984 if (diff_src_md.format_kind == format_kind::any) {
1985 CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag));
1986 } else if (curr_src_tag != dat_tag)
1987 return status::unimplemented;
1988 jcp.src_tag = dat_tag;
1989
1990 if (diff_dst_md.format_kind == format_kind::any) {
1991 CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag));
1992 } else if (curr_dst_tag != dat_tag)
1993 return status::unimplemented;
1994 jcp.dst_tag = dat_tag;
1995
1996 if (init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag)
1997 != status::success)
1998 return status::unimplemented;
1999
2000 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
2001 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
2002
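    // Pick the iw unroll factor: use the full width when it fits into the
    // 28-register budget, otherwise the largest value <= 28 that is a
    // multiple of stride_w so every unrolled iteration starts on a stride
    // boundary.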
2003 jcp.ur_w = jcp.stride_w;
2004
2005 int regs = 28;
2006 if (jcp.iw <= regs)
2007 jcp.ur_w = jcp.iw;
2008 else {
2009 for (int ur_w = regs; ur_w > 0; --ur_w)
2010 if (ur_w % jcp.stride_w == 0) {
2011 jcp.ur_w = ur_w;
2012 break;
2013 }
2014 }
2015 int l_overflow = nstl::max(
2016 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
2017 int r_overflow_no_tail = nstl::max(0,
2018 ((jcp.kw - 1) * (jcp.dilate_w + 1)
2019 - nstl::max(0, jcp.r_pad + jcp.iw % jcp.ur_w))
2020 / jcp.stride_w);
2021 int n_oi = jcp.iw / jcp.ur_w;
2022 if (r_overflow_no_tail > 0) n_oi--;
2023
2024 jcp.typesize_in = typesize;
2025 jcp.typesize_out = typesize;
2026
2027 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2028
2029 // Heuristic to optimize code size on KNX
2030 bool large_code_size = (jcp.ur_w != jcp.ow)
2031 && ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1))
2032 && (r_overflow_no_tail > 0) && (l_overflow > 0);
2033 if (large_code_size) {
2034 const int max_code_size = 24 * 1024;
2035 const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
2036 int mult = 1;
2037 if (l_overflow > 0) mult += 1;
2038 if (r_overflow_no_tail > 0) mult += 1;
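        // Shrink ur_w until the estimated code size, roughly
        // (ur_w / stride_w) * mult * num_ops_per_reg instructions at ~9.2
        // bytes each, fits under the 24 KB budget.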
2039 for (int ur_w = jcp.ur_w; ur_w > regs / 2; --ur_w) {
2040 if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
2041 < max_code_size) {
2042 if (ur_w % jcp.stride_w == 0) {
2043 jcp.ur_w = ur_w;
2044 break;
2045 }
2046 }
2047 }
2048 }
2049
    /* Support for large filter 'kw > 14' is only possible when ur_w is small
     * (e.g., ur_w = 1) because of register allocation (max_reg = 31) */
2052 const int min_filter_size = 14;
    /* Don't let the JIT generate code that is too large, as it might result
     * in an out-of-memory crash. */
2055 const int max_filter_size = 20;
2056
2057 /* These conditions define a set of shapes with 'ow = 1' which
2058 * have a very limited optimization space for performance.
2059 * Optimize by using a targeted 'jcp.nb_ic_blocking' value. */
2060 jcp.large_w_filter = jcp.kw >= min_filter_size && jcp.kw < max_filter_size
2061 && jcp.ow == 1 && jcp.nb_ic > 1 && jcp.kw == jcp.iw
2062 && jcp.stride_w == 1
2063 && utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
2064
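    // Choose the fma kernel flavor from an estimate of the per-row working
    // set (inputs, outputs and weights for nb_ic_blocking = 2) against the
    // L1 cache, then pick ur_w so that ur_w * (nb_ic_blocking + 1) vector
    // registers (accumulators plus broadcasts) fit into the 31 usable zmm
    // registers.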
2065 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
2066 if (mayiuse(avx512_core)) {
2067 int try_nb_ic_blocking = 2;
2068 unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
2069 * try_nb_ic_blocking * jcp.kh;
2070 unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
2071 unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
2072 * jcp.oc_block * try_nb_ic_blocking;
2073 unsigned int ker_total_size
2074 = ker_inp_size + ker_out_size + ker_wei_size;
2075 bool use_expl_bcast
2076 = !(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
2077 || (jcp.kw < 5
2078 && ((jcp.iw <= 5
2079 || (jcp.iw > 8 && jcp.iw <= 13))
2080 || ker_total_size > L1_cache_size)))
2081 || jcp.stride_h > 1 || jcp.stride_d > 1;
2082 if (use_expl_bcast && !jcp.large_w_filter) {
2083 jcp.kernel_kind = embd_bcast;
2084 jcp.ur_w = nstl::min(jcp.iw, regs);
2085 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2086 if (!(jcp.kw > 3
2087 || (jcp.kw == 3 && ker_total_size < L1_cache_size
2088 && jcp.ow > 8))
2089 && jcp.stride_h == 1 && jcp.stride_d == 1)
2090 if (jcp.nb_ic % try_nb_ic_blocking == 0) {
2091 jcp.nb_ic_blocking = try_nb_ic_blocking;
2092 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2093 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2094 }
2095 } else {
2096 jcp.kernel_kind = expl_bcast;
2097 jcp.nb_oc_blocking = 1;
2098 jcp.nb_ic_blocking = jcp.large_w_filter ? 2 : 4;
2099 if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2100 if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2101 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2102 if (jcp.nb_ic % i == 0) {
2103 jcp.nb_ic_blocking = i;
2104 break;
2105 }
2106 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2107 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2108 }
2109 }
2110 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
2111
2112 auto is_iw_threading_applicable = [=]() { return one_of(jcp.ndims, 3, 4); };
2113
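    // Threading efficiency estimate: penalize small iw blocks (per-block
    // overhead), the imbalance from rounding iw up to a multiple of
    // iw_block, and the load imbalance from rounding the work amount up to
    // a multiple of nthr.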
2114 auto get_thr_eff = [=](int nb_ic_blocking, int iw_block, int nthr) {
2115 // Cost heuristic for threading overhead. Determined using OMP.
2116 const float iw_block_cost = 32.0;
2117
2118 int nb_iw = div_up(jcp.iw, iw_block);
2119 int nb_ic_chunks = div_up(jcp.nb_ic, nb_ic_blocking);
2120 int work_amount = jcp.mb * jcp.ih * nb_ic_chunks * nb_iw;
2121 float disbalance = (float)jcp.iw / rnd_up(jcp.iw, iw_block);
2122 float block_overhead = nstl::max(0.0f, 1.0f - iw_block_cost / iw_block);
2123 float thr_eff = block_overhead * disbalance
2124 * ((float)work_amount / rnd_up(work_amount, nthr));
2125 return thr_eff;
2126 };
2127
2128 auto get_iw_block = [=](int nb_ic_blocking, int ur_w, float &eff,
2129 int nthr) {
2130 int res_iw_block = jcp.iw;
2131 if (!is_iw_threading_applicable()) return res_iw_block;
2132
2133 int max_nb_iw = div_up(jcp.iw, 2 * ur_w);
2134 int iw_block_thr;
2135
2136 if (jcp.ndims == 3) {
            // Blocking optimization to prevent data from leaving the cache.
            // This optimization does not handle height blocking, so it does
            // not apply to higher dimensions.
2140 // TODO: Implement a more general optimization taking into account
2141 // the height dimension.
2142 int L2_part
2143 = (platform::get_per_core_cache_size(2) * 7 / 8) / typesize;
2144 int size_diff_src_chunk = jcp.ic_block * nb_ic_blocking * ur_w;
2145 int size_diff_dst_chunk = jcp.oc_block * ur_w;
2146 int size_wei_chunk
2147 = jcp.ic_block * nb_ic_blocking * jcp.oc_block * jcp.kw;
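            // Number of ur_w chunks whose double-buffered diff_src and
            // diff_dst working sets fit in ~7/8 of L2 after reserving room
            // for two weights chunks.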
2148 int nurw_cache = (L2_part - 2 * size_wei_chunk)
2149 / (2 * size_diff_dst_chunk + 2 * size_diff_src_chunk);
2150 // current design of generate() requires iw_block >= 2 * ur_w
2151 int iw_block_cache = ur_w * nstl::max(2, nurw_cache);
2152
2153 iw_block_thr = iw_block_cache;
2154 } else
2155 iw_block_thr = jcp.iw;
2156 eff = get_thr_eff(nb_ic_blocking, iw_block_thr, nthr);
2157
2158 // Search for most efficient threading over iw_blocks.
2159 int start_nb_iw = div_up(jcp.iw, iw_block_thr);
2160 for (int nb_iw = start_nb_iw; nb_iw <= max_nb_iw; nb_iw++) {
2161 float eff_threshold = 0.98f;
2162 if (eff > eff_threshold) break;
2163 int iw_block
2164 = nstl::min(rnd_up(div_up(jcp.iw, nb_iw), ur_w), jcp.iw);
2165 if (div_up(jcp.iw, iw_block) != nb_iw) continue;
2166 float thr_eff = get_thr_eff(nb_ic_blocking, iw_block, nthr);
2167 if (iw_block >= 2 * ur_w && thr_eff > eff) {
2168 iw_block_thr = iw_block;
2169 eff = thr_eff;
2170 }
2171 }
2172 res_iw_block = nstl::min(jcp.iw, nstl::max(2 * ur_w, iw_block_thr));
2173 return res_iw_block;
2174 };
2175
2176 float thr_eff = -1.0f;
2177 jcp.iw_block
2178 = get_iw_block(jcp.nb_ic_blocking, jcp.ur_w, thr_eff, jcp.nthr);
2179 jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
2180
    /* Adjust the thread decomposition to improve thr_eff for small-size
     * problems; the L1_cache_size threshold is empirical. */
2184 size_t wei_size
2185 = (size_t)typesize * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd;
2186 size_t out_size = (size_t)jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh
2187 * jcp.ow * jcp.od;
2188 size_t inp_size = (size_t)jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih
2189 * jcp.iw * jcp.id;
2190 size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size);
2191
2192 if (jcp.ngroups < jcp.nthr && (total_size < L1_cache_size)) {
2193 int iw_block = jcp.iw_block;
2194 int end_nthr = with_groups ? jcp.ngroups : ndims - 2;
2195 float eff = -1.0f;
2196 float best_thr_eff = -1.0f;
        // When thr_eff equals zero (the proper efficiency cannot be
        // obtained), simply cap the number of threads at 4 for now,
        // and update eff inside get_iw_block to avoid redundant
        // computation when thr_eff != 0 but the current eff == 0.
2201 if (thr_eff == 0.f) {
2202 jcp.nthr = nstl::min(jcp.nthr, 4);
2203 } else {
2204 for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) {
2205 iw_block
2206 = get_iw_block(jcp.nb_ic_blocking, jcp.ur_w, eff, nthr);
2207 if (eff > 1.1f * best_thr_eff) {
2208 best_thr_eff = eff;
2209 jcp.iw_block = iw_block;
2210 jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
2211 jcp.nthr = jcp.aligned_threads = nthr;
2212 if (best_thr_eff > 0.98f) break;
2213 }
2214 }
2215 }
2216 }
2217
2218 if (l_overflow * jcp.stride_w > jcp.ur_w && !jcp.large_w_filter)
2219 return status::unimplemented;
2220
2221 r_overflow_no_tail = nstl::max(0,
2222 ((jcp.kw - 1) * (jcp.dilate_w + 1)
2223 - nstl::max(0, jcp.r_pad + jcp.ur_w_tail))
2224 / jcp.stride_w);
2225 bool tails_not_ok = false
2226 /* maximum 1 ur_w block with r_overflow so far */
2227 || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
2228 /* ur_w must be a multiple of stride */
2229 || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2230 /* r_pad must not extend beyond ur_w_tail */
2231 || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
2232 if (tails_not_ok) return status::unimplemented;
2233
2234 pick_loop_order(jcp);
2235
2236 jcp.nb_oc_L2 = jcp.nb_oc;
2237 if (is_data_layout_nxc) {
2238 // TODO: improve L2 blocking for large OC
        const int nb_oc_threshold_L2 = 32;
        if (jcp.nb_oc > nb_oc_threshold_L2 && jcp.nb_oc < 2 * nb_oc_threshold_L2)
            jcp.nb_oc_L2 = div_up(jcp.nb_oc, 2);
        else
            jcp.nb_oc_L2 = nstl::min(nb_oc_threshold_L2, jcp.nb_oc);
2244 }
2245
2246 bool args_ok = true && jcp.ic <= diff_src_d.padded_dims()[1]
2247 && jcp.oc <= diff_dst_d.padded_dims()[1]
2248 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
2249 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
2250 if (!args_ok) return status::unimplemented;
2251
2252 // A rough check on code size
2253 // TODO: come up with a tighter bound
2254 {
2255 const int max_code_size = 256 * 1024; // default size of jit generator
2256 int mult = 1 + (l_overflow > 0) + (r_overflow_no_tail > 0);
2257 const float max_instruction_size = 15;
2258 float ur_fac
2259 = (float)jcp.kw * jcp.oc_block * jcp.nb_ic_blocking * jcp.ur_w;
2260 float code_size = mult * ur_fac * max_instruction_size;
2261 if (code_size > max_code_size && !jcp.large_w_filter)
2262 return status::unimplemented;
2263 }
2264
2265 return status::success;
2266}
2267
2268void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
2269 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
2270 UNUSED(scratchpad);
2271 UNUSED(jcp);
2272}
2273
2274// Initialize static data members
2275const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
2276const int jit_avx512_common_conv_bwd_weights_kernel_f32::min_oh_reduce = 9;
2277
2278void jit_avx512_common_conv_bwd_weights_kernel_f32::
2279 od_step_comeback_pointers() {
2280 Label kd_comeback_label;
2281
2282 /* 'depth' loop count bound by 'kd_work_size' */
2283 mov(kj, reg_kd_count);
2284 L(kd_comeback_label);
2285 {
2286 int inp_mult = is_src_layout_nxc()
2287 ? jcp.ngroups * jcp.ic
2288 : (jcp.is_1stconv ? 1 : jcp.ic_block);
2289 int iw = jcp.iw;
2290 sub(reg_input,
2291 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
2292 sub(reg_kernel,
2293 jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block
2294 * jcp.oc_block);
2295 dec(kj);
2296 cmp(kj, 0);
2297 jg(kd_comeback_label, T_NEAR);
2298 }
2299}
2300
2301void jit_avx512_common_conv_bwd_weights_kernel_f32::
2302 oh_step_comeback_pointers() {
2303 Label kh_comeback_label, kd_comeback_label;
2304 mov(kj, reg_kh);
2305 L(kh_comeback_label);
2306 {
2307 int kw = jcp.is_hw_transp ? 1 : jcp.kw;
2308 int inp_mult = is_src_layout_nxc()
2309 ? jcp.ngroups * jcp.ic
2310 : (jcp.is_1stconv ? 1 : jcp.ic_block);
2311 int iw = jcp.is_hw_transp ? 1 : jcp.iw;
2312 sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
2313 sub(reg_kernel, jcp.typesize_out * kw * jcp.ic_block * jcp.oc_block);
2314 dec(kj);
2315 cmp(kj, 0);
2316 jg(kh_comeback_label, T_NEAR);
2317 }
2318}
2319
2320void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
2321 int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
2322 int kernel_offset, int output_offset, bool input_wraparound) {
2323
2324 int kw = jcp.is_hw_transp ? jcp.tr_kw : jcp.kw;
2325 int iw = jcp.is_hw_transp ? jcp.tr_iw : jcp.iw;
2326 int kw_tr_mult = jcp.is_hw_transp ? jcp.kw : 1;
2327 int ic_block = jcp.ic_block;
2328 int oc_block = jcp.oc_block;
2329 auto get_ker_offt = [=](int i_kw, int i_ic) {
2330 return typesize * (i_kw * kw_tr_mult * ic_block + i_ic) * jcp.oc_block
2331 + kernel_offset;
2332 };
2333 for (int i_kw = 0; i_kw < kw; i_kw++)
2334 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2335 vmovups(Zmm(i_kw * ic_block_step + i_ic),
2336 EVEX_compress_addr(reg_kernel, get_ker_offt(i_kw, i_ic)));
2337 const int out_mult = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : oc_block;
2338 const int oc_tail = jcp.oc_tail;
2339
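    // Register layout: zmm[0 .. kw * ic_block_step) hold the diff-weights
    // accumulators loaded above; a rotating window of ddst_pipeline_len
    // registers holds diff_dst vectors, refilled ddst_pipeline_len - 1
    // iterations ahead, while src values are broadcast straight from memory
    // in the fma below.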
2340 for (int i_ur = 0; i_ur < ur_w; i_ur++) {
2341 const int ddst_pipeline_start_idx = ic_block_step * kw;
2342 const int ddst_pipeline_len = 4;
2343 auto get_ddst_reg_idx = [=](int ur_idx) {
2344 return ddst_pipeline_start_idx + (ur_idx) % ddst_pipeline_len;
2345 };
2346 auto get_ddst_offt = [=](int ur_idx) {
2347 return typesize * ur_idx * out_mult + output_offset;
2348 };
2349
2350 if (i_ur == 0) {
2351 for (int i = 0; i < nstl::min(ddst_pipeline_len, ur_w); i++) {
2352 int ur_idx = i_ur + i;
2353 auto zmm_ddst = Zmm(get_ddst_reg_idx(ur_idx));
2354 if (oc_tail) zmm_ddst = zmm_ddst | k_oc_mask | T_z;
2355 vmovups(zmm_ddst,
2356 EVEX_compress_addr(reg_output, get_ddst_offt(ur_idx)));
2357 }
2358 } else if (i_ur + ddst_pipeline_len - 1 < ur_w) {
2359
2360 int ur_idx = i_ur + ddst_pipeline_len - 1;
2361
2362 auto zmm_ddst = Zmm(get_ddst_reg_idx(ur_idx));
2363 if (oc_tail) zmm_ddst = zmm_ddst | k_oc_mask | T_z;
2364 vmovups(zmm_ddst,
2365 EVEX_compress_addr(reg_output, get_ddst_offt(ur_idx)));
2366 }
2367
2368 for (int i_kw = 0; i_kw < kw; i_kw++) {
2369 int i_iw = get_iw_idx(i_ur, i_kw, pad_l);
2370 if (i_iw < 0 || i_iw > get_iw_idx(ur_w - 1, kw - 1, pad_l) - pad_r
2371 || get_iw_idx(i_ur, i_kw, jcp.l_pad) >= iw)
2372 continue;
2373 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2374 vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
2375 Zmm(get_ddst_reg_idx(i_ur)),
2376 EVEX_compress_addr_safe(reg_input,
2377 get_full_src_offset(i_iw, i_ic, input_offset),
2378 reg_long_offt, true));
2379 }
2380 }
2381 }
2382
2383 for (int i_kw = 0; i_kw < kw; i_kw++)
2384 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2385 vmovups(EVEX_compress_addr(reg_kernel, get_ker_offt(i_kw, i_ic)),
2386 Zmm(i_kw * ic_block_step + i_ic));
2387}
2388
2389void jit_avx512_common_conv_bwd_weights_kernel_f32::
2390 compute_ic_block_step_fma_expl(int ur_w, int pad_l, int pad_r,
2391 int ic_block_step, int input_offset, int kernel_offset,
2392 int output_offset, bool input_wraparound) {
2393 int kw = jcp.kw;
2394 int ic_block = jcp.ic_block;
2395 int oc_block = jcp.oc_block;
2396 const int oc_tail = jcp.oc_tail;
2397 const bool ddst_layout_nxc = is_ddst_layout_nxc();
2398 const int max_regs = 32;
2399 const int ddst_pipeline_start_idx = 2 * ic_block_step * kw;
2400 const int ddst_pipeline_len
2401 = ddst_layout_nxc ? 1 : max_regs - ddst_pipeline_start_idx;
2402 const int iw_last_value = get_iw_idx(ur_w - 1, kw - 1, pad_l) - pad_r;
2403 assert(jcp.stride_w == 1 && jcp.dilate_w == 0 && ddst_pipeline_len > 0
2404 && jcp.kernel_kind == expl_bcast);
2405
2406 const int out_mult = ddst_layout_nxc ? jcp.ngroups * jcp.oc : oc_block;
2407 auto get_diff_wei_reg_idx
2408 = [=](int i_kw, int i_ic) { return i_kw * ic_block_step + i_ic; };
2409 auto get_src_reg_idx = [=](int i_iw, int i_ic) {
2410 return kw * ic_block_step + ((i_iw + pad_l) % kw) * ic_block_step
2411 + i_ic;
2412 };
2413 auto get_diff_dst_reg_idx = [=](int i_ur) {
2414 return ddst_pipeline_start_idx + i_ur % ddst_pipeline_len;
2415 };
2416
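    // Register layout: zmm[0 .. kw * ic_block_step) are the diff-weights
    // accumulators (zeroed below), the next kw * ic_block_step registers
    // hold broadcast src values rotated by (i_iw + pad_l) % kw, and the
    // remaining registers form the diff_dst pipeline.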
2417 for (int i_kw = 0; i_kw < kw; i_kw++)
2418 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2419 auto zmm_ker = Zmm(get_diff_wei_reg_idx(i_kw, i_ic));
2420 vpxord(zmm_ker, zmm_ker, zmm_ker);
2421 }
2422
2423 for (int i_ur = 0; i_ur < ur_w; i_ur++) {
2424 if (i_ur == 0) {
2425 for (int i = 0; i < nstl::min(ddst_pipeline_len, ur_w); i++) {
2426 auto addr_out = EVEX_compress_addr(
2427 reg_output, typesize * i * out_mult + output_offset);
2428 auto zmm_ddst = Zmm(get_diff_dst_reg_idx(i));
2429 if (oc_tail) zmm_ddst = zmm_ddst | k_oc_mask | T_z;
2430 vmovups(zmm_ddst, addr_out);
2431 }
2432
2433 for (int i_kw = 0; i_kw < kw; i_kw++) {
2434 int i_iw = get_iw_idx(0, i_kw, pad_l);
2435 if (i_iw < 0 || i_iw > iw_last_value) continue;
2436
2437 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2438 auto addr_inp = EVEX_compress_addr_safe(reg_input,
2439 get_full_src_offset(i_iw, i_ic, input_offset),
2440 reg_long_offt);
2441 vbroadcastss(Zmm(get_src_reg_idx(i_iw, i_ic)), addr_inp);
2442 }
2443 }
2444 } else {
2445 int diff_dst_load_idx = i_ur + ddst_pipeline_len - 1;
2446 if (diff_dst_load_idx < ur_w) {
2447 auto addr_out = EVEX_compress_addr(reg_output,
2448 typesize * diff_dst_load_idx * out_mult
2449 + output_offset);
2450 auto zmm_ddst = Zmm(get_diff_dst_reg_idx(diff_dst_load_idx));
2451 if (oc_tail) zmm_ddst = zmm_ddst | k_oc_mask | T_z;
2452 vmovups(zmm_ddst, addr_out);
2453 }
2454
2455 int i_iw = get_iw_idx(i_ur, kw - 1, pad_l);
2456 if (i_iw >= 0 && i_iw <= iw_last_value) {
2457 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2458 auto addr_inp = EVEX_compress_addr_safe(reg_input,
2459 get_full_src_offset(i_iw, i_ic, input_offset),
2460 reg_long_offt);
2461 vbroadcastss(Zmm(get_src_reg_idx(i_iw, i_ic)), addr_inp);
2462 }
2463 }
2464 }
2465 for (int i_kw = 0; i_kw < kw; i_kw++) {
2466 int i_iw = get_iw_idx(i_ur, i_kw, pad_l);
2467 if (i_iw < 0 || i_iw > iw_last_value) continue;
2468 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2469 vfmadd231ps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2470 Zmm(get_diff_dst_reg_idx(i_ur)),
2471 Zmm(get_src_reg_idx(i_iw, i_ic)));
2472 }
2473 }
2474 }
2475
2476 for (int i_kw = 0; i_kw < kw; i_kw++)
2477 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2478 auto addr_ker = EVEX_compress_addr(reg_kernel,
2479 typesize * (i_kw * ic_block + i_ic) * jcp.oc_block
2480 + kernel_offset);
2481 vaddps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)), addr_ker);
2482 }
2483
2484 for (int i_kw = 0; i_kw < kw; i_kw++)
2485 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2486 auto addr_ker = EVEX_compress_addr(reg_kernel,
2487 typesize * (i_kw * ic_block + i_ic) * jcp.oc_block
2488 + kernel_offset);
2489 vmovups(addr_ker, Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2490 }
2491}
2492
2493void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
2494 int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
2495 int kernel_offset, int output_offset, bool input_wraparound) {
2496 if (jcp.kernel_kind == expl_bcast)
2497 compute_ic_block_step_fma_expl(ur_w, pad_l, pad_r, ic_block_step,
2498 input_offset, kernel_offset, output_offset, input_wraparound);
2499 else
2500 compute_ic_block_step_fma(ur_w, pad_l, pad_r, ic_block_step,
2501 input_offset, kernel_offset, output_offset, input_wraparound);
2502}
2503
2504void jit_avx512_common_conv_bwd_weights_kernel_f32 ::
2505 compute_oh_step_unroll_ow_icblock(int ic_block_step, int max_ur_w) {
2506 UNUSED(max_ur_w);
2507
2508 Label kh_label, kd_label;
2509
2510 int ic_block = jcp.ic_block;
2511 int oc_block = jcp.oc_block;
2512 const bool src_layout_nxc = is_src_layout_nxc();
2513 int inp_mul = src_layout_nxc ? jcp.ngroups * jcp.ic
2514 : (!jcp.is_1stconv ? ic_block : 1);
2515 int iw = jcp.iw;
2516
2517 int r_pad = nstl::max(0, jcp.r_pad);
2518 int l_pad = jcp.l_pad;
2519
2520 if (jcp.ndims == 5) {
2521 L(kd_label);
2522 mov(reg_input, aux_reg_input);
2523 mov(reg_kernel, aux_reg_kernel);
2524 }
2525
2526 const int ic_tail = jcp.ic_tail;
2527 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2528 mov(kj, reg_kh);
2529 L(kh_label);
2530 {
2531 Label icb_block_label, icb_block_label_cb, ic_tail_loop, ic_tail_label;
2532 if (generate_icb_loop || ic_tail) {
2533 push(reg_input);
2534 push(reg_kernel);
2535 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2536 }
2537
2538 if (ic_tail) {
2539 cmp(reg_icb, ic_block);
2540 jl(ic_tail_loop, T_NEAR);
2541 }
2542
2543 const int ic_tail_loop_work = rnd_dn(ic_tail, ic_block_step);
2544 Label icb_block_label_end;
2545 L(icb_block_label);
2546 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
2547 const int input_offset = jcp.typesize_in * i_b_ic;
2548 compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
2549 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
2550 i_b_ic + ic_block_step >= jcp.ic_block);
2551 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2552 if (ic_tail && i_b_ic + ic_block_step == ic_tail_loop_work) {
2553 cmp(reg_icb, ic_block_step);
2554 jl(icb_block_label_end, T_NEAR);
2555 }
2556 }
2557 L(icb_block_label_end);
2558
2559 const int input_icb_shift = jcp.typesize_in * ic_block;
2560 const size_t kernel_icb_shift = (size_t)jcp.typesize_out * jcp.kd
2561 * jcp.kh * jcp.kw * ic_block * oc_block;
2562
2563 if (generate_icb_loop) {
2564 // icb loop supported for src in nxc layout only
2565 assert(src_layout_nxc);
2566 add(reg_input, input_icb_shift);
2567 safe_add(reg_kernel, kernel_icb_shift, reg_long_offt);
2568 cmp(reg_icb, ic_block);
2569 jge(icb_block_label, T_NEAR);
2570 }
2571
2572 if (ic_tail) {
2573 L(ic_tail_loop);
2574 Label skip_ic_tail;
2575 cmp(reg_icb, 0);
2576 jle(skip_ic_tail, T_NEAR);
2577 if (ic_tail_loop_work) {
2578 cmp(reg_icb, ic_tail_loop_work);
2579 jge(icb_block_label, T_NEAR);
2580 if (generate_icb_loop) {
2581 // compensate offset added in generate_icb_loop
2582 sub(reg_input, input_icb_shift);
2583 safe_sub(reg_kernel, kernel_icb_shift, reg_long_offt);
2584 }
2585 }
2586
2587 L(ic_tail_label);
2588 if (ic_tail % ic_block_step) {
2589 cmp(reg_icb, 0);
2590 jle(skip_ic_tail, T_NEAR);
2591 const int i_b_ic = ic_tail_loop_work;
2592 const int input_offset = jcp.typesize_in * i_b_ic;
2593 compute_ic_block_step(jcp.ur_w, l_pad, r_pad,
2594 ic_tail % ic_block_step, input_offset,
2595 jcp.typesize_out * i_b_ic * jcp.oc_block, 0);
2596 }
2597 L(skip_ic_tail);
2598 }
2599
2600 if (generate_icb_loop || ic_tail) {
2601 pop(reg_kernel);
2602 pop(reg_input);
2603 }
2604
2605 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
2606 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
2607 dec(kj);
2608 cmp(kj, 0);
2609 jg(kh_label, T_NEAR);
2610 }
2611
2612 if (jcp.ndims == 5) {
2613 add(aux_reg_input,
2614 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
2615 add(aux_reg_kernel,
2616 jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block);
2617 dec(ki);
2618 cmp(ki, 0);
2619 jg(kd_label, T_NEAR);
2620 }
2621}
2622
2623void jit_avx512_common_conv_bwd_weights_kernel_f32 ::compute_oh_step_unroll_ow(
2624 int ic_block_step, int max_ur_w) {
2625 Label kh_label, ic_block_label, ic_tail_loop_label, ic_tail_label, kd_label;
2626 const bool src_layout_nxc = is_src_layout_nxc();
2627 int inp_mul = src_layout_nxc ? jcp.ngroups * jcp.ic
2628 : (!jcp.is_1stconv ? jcp.ic_block : 1);
2629 const int ic_tail = jcp.ic_tail;
2630 UNUSED(max_ur_w);
2631
2632 int ic_block = jcp.ic_block;
2633 int oc_block = jcp.oc_block;
2634
2635 int inp_icb_sp_stride = jcp.is_hw_transp ? 1 : jcp.iw;
2636 int ow = jcp.is_hw_transp ? jcp.oh : jcp.ow;
2637
2638 int r_pad = nstl::max(0, jcp.r_pad);
2639 int l_pad = jcp.l_pad;
2640
2641 if (jcp.ndims == 5) {
2642 L(kd_label);
2643 mov(reg_input, aux_reg_input);
2644 mov(reg_kernel, aux_reg_kernel);
2645 }
2646
2647 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2648 mov(kj, reg_kh);
2649 L(kh_label);
2650 {
2651 Label icb_block_label;
2652 if (generate_icb_loop || ic_tail) {
2653 push(reg_input);
2654 push(reg_kernel);
2655 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2656 }
2657
2658 if (ic_tail) {
2659 cmp(reg_icb, ic_block);
2660 jl(ic_tail_loop_label, T_NEAR);
2661 }
2662
2663 L(icb_block_label);
2664 Label icb_block_label_end;
2665 mov(b_ic, ic_block);
2666 L(ic_block_label);
2667 {
2668 compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, 0, 0, 0);
2669 size_t inp_icblk_stride = jcp.is_1stconv && !src_layout_nxc
2670 ? (size_t)jcp.ih * jcp.iw * jcp.id
2671 : 1;
2672 size_t input_offset
2673 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
2674 safe_add(reg_input, input_offset, reg_long_offt);
2675 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
2676 sub(b_ic, ic_block_step);
2677 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2678 cmp(b_ic, ic_block_step);
2679 jge(ic_block_label, T_NEAR);
2680 }
2681 L(icb_block_label_end);
2682
2683 const int input_shift = jcp.typesize_in * (jcp.dilate_h + 1)
2684 * inp_icb_sp_stride * inp_mul;
2685
2686 if (generate_icb_loop || ic_tail) {
2687 const size_t kernel_icb_shift = (size_t)jcp.typesize_out * jcp.kd
2688 * jcp.kh * jcp.kw * ic_block * oc_block;
2689 if (generate_icb_loop) {
2690 // icb loop supported for src in nxc layout only
2691 assert(src_layout_nxc);
2692 Label icb_loop_done;
2693 safe_add(reg_kernel,
2694 kernel_icb_shift
2695 - jcp.typesize_out * ic_block * oc_block,
2696 reg_long_offt);
2697 cmp(reg_icb, ic_block);
2698 jge(icb_block_label, T_NEAR);
2699 L(icb_loop_done);
2700 }
2701
2702 L(ic_tail_loop_label);
2703 if (ic_tail) {
2704 Label skip_ic_tail;
2705 const int ic_tail_loop_work = rnd_dn(ic_tail, ic_block_step);
2706 cmp(reg_icb, 0);
2707 jle(skip_ic_tail, T_NEAR);
2708 mov(b_ic, reg_icb);
2709 if (ic_tail_loop_work) {
2710 cmp(reg_icb, ic_block_step);
2711 jge(ic_block_label, T_NEAR);
2712 if (generate_icb_loop) {
2713 // compensate offset added in generate_icb_loop
2714 safe_sub(reg_kernel,
2715 kernel_icb_shift
2716 - jcp.typesize_out * ic_block
2717 * oc_block,
2718 reg_long_offt);
2719 }
2720 }
2721
2722 L(ic_tail_label);
2723 if (ic_tail % ic_block_step) {
2724 cmp(reg_icb, 0);
2725 jle(skip_ic_tail, T_NEAR);
2726 compute_ic_block_step(
2727 ow, l_pad, r_pad, ic_tail % ic_block_step, 0, 0, 0);
2728 }
2729 L(skip_ic_tail);
2730 }
2731
2732 pop(reg_kernel);
2733 pop(reg_input);
2734
2735 add(reg_input, input_shift);
2736 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
2737
2738 } else if (jcp.is_1stconv && !src_layout_nxc) {
2739 size_t input_offset = (size_t)jcp.typesize_in * jcp.id * jcp.ih
2740 * jcp.iw * ic_block;
2741 safe_sub(reg_input, input_offset, reg_long_offt);
2742 add(reg_input, input_shift);
2743 } else {
2744 add(reg_input, input_shift - jcp.typesize_in * jcp.ic_block);
2745 }
2746
2747 if (!jcp.is_hw_transp && !(generate_icb_loop || ic_tail))
2748 add(reg_kernel,
2749 jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
2750 dec(kj);
2751 cmp(kj, 0);
2752 jg(kh_label, T_NEAR);
2753 }
2754 if (jcp.ndims == 5) {
2755 add(aux_reg_input,
2756 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * jcp.iw
2757 * inp_mul);
2758 add(aux_reg_kernel,
2759 jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block);
2760 dec(ki);
2761 cmp(ki, 0);
2762 jg(kd_label, T_NEAR);
2763 }
2764}
2765
2766void jit_avx512_common_conv_bwd_weights_kernel_f32 ::compute_oh_step_common(
2767 int ic_block_step, int max_ur_w) {
2768 using namespace nstl;
2769 Label kh_label, ic_block_label, ic_tail_loop_label, ic_tail_label, kd_label;
2770
2771 const bool src_layout_nxc = is_src_layout_nxc();
2772 int ic_block = jcp.ic_block;
2773 int oc_block = jcp.oc_block;
2774
2775 int ow = jcp.is_hw_transp ? jcp.oh : jcp.ow;
2776 int r_pad = max(0, jcp.r_pad);
2777 int l_pad = jcp.l_pad;
2778
2779 int ur_w = min(ow, max_ur_w);
2780 int ur_w_trips = ow / ur_w;
2781 int ur_w_tail = ow % ur_w;
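    // If the ur_w tail is empty or fully covered by right padding, grow the
    // tail by one ur_w block (or split ur_w in half when there is a single
    // trip) so that the right padding is handled entirely in the tail step.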
2782 if ((ur_w_tail == 0 && r_pad != 0) || (r_pad > 0 && r_pad >= ur_w_tail)) {
2783 if (ur_w_trips > 1) {
2784 ur_w_tail += ur_w;
2785 ur_w_trips--;
2786 } else {
2787 ur_w_tail += (ur_w - ur_w / 2);
2788 ur_w = ur_w / 2;
2789 }
2790 }
2791
2792 assert(l_pad <= max_ur_w);
2793 int inp_mult = src_layout_nxc
2794 ? jcp.ngroups * jcp.ic
2795 : (jcp.is_1stconv ? 1 : ic_block * (jcp.is_hw_transp ? jcp.iw : 1));
2796 int out_mult = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : oc_block;
2797 int input_comeback
2798 = max((ur_w_trips * ur_w * jcp.stride_w - l_pad), 0) * inp_mult;
2799 int output_comeback = ur_w_trips * ur_w * out_mult;
2800 const int ic_tail = jcp.ic_tail;
2801 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2802
2803 auto ic_loop = [=](int ic_block_step) {
2804 Label ow_block_label, ic_block_inner_label;
2805 int ur_w_blocks = ur_w_trips;
2806
2807 int l_pad_tail = max(l_pad - ur_w, 0);
2808 L(ic_block_inner_label);
2809 if (l_pad != 0) {
2810 ur_w_blocks--;
2811 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
2812 int iw_offset = ur_w * jcp.stride_w - l_pad;
2813 if (iw_offset > 0)
2814 add(reg_input, jcp.typesize_in * iw_offset * inp_mult);
2815 add(reg_output, jcp.typesize_in * ur_w * out_mult);
2816 }
2817
2818 assert(IMPLICATION(l_pad_tail > 0, ur_w_blocks <= 1));
2819 if (ur_w_blocks > 0) {
2820 xor_(reg_ur_w_trips, reg_ur_w_trips);
2821 L(ow_block_label);
2822 {
2823 compute_ic_block_step(
2824 ur_w, l_pad_tail, 0, ic_block_step, 0, 0, 0);
2825 add(reg_input,
2826 jcp.typesize_in * (ur_w * jcp.stride_w - l_pad_tail)
2827 * inp_mult);
2828 add(reg_output, jcp.typesize_in * ur_w * out_mult);
2829
2830 inc(reg_ur_w_trips);
2831 cmp(reg_ur_w_trips, ur_w_blocks);
2832 jl(ow_block_label, T_NEAR);
2833 l_pad_tail = max(l_pad_tail - ur_w, 0);
2834 }
2835 }
2836
2837 if (ur_w_tail > 0)
2838 compute_ic_block_step(
2839 ur_w_tail, l_pad_tail, r_pad, ic_block_step, 0, 0, 0);
2840
2841 sub(reg_output, jcp.typesize_in * output_comeback);
2842 };
2843
2844 if (jcp.ndims == 5) {
2845 L(kd_label);
2846 mov(reg_input, aux_reg_input);
2847 mov(reg_kernel, aux_reg_kernel);
2848 }
2849
2850 mov(kj, reg_kh);
2851 L(kh_label);
2852 {
2853 Label icb_block_label, icb_block_label_cb;
2854 if (generate_icb_loop || ic_tail) {
            // TODO: Maybe a broadcast would work here?
2856 push(reg_input);
2857 push(reg_kernel);
2858 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2859 }
2860
2861 if (ic_tail) {
2862 cmp(reg_icb, ic_block);
2863 jl(ic_tail_loop_label, T_NEAR);
2864 }
2865
2866 L(icb_block_label);
2867 mov(b_ic, ic_block);
2868 L(ic_block_label);
2869 Label ic_block_label_end;
2870 {
2871 ic_loop(ic_block_step);
2872 sub(reg_input, jcp.typesize_in * input_comeback);
2873 int inp_icblk_stride = jcp.is_1stconv && !src_layout_nxc
2874 ? jcp.ih * jcp.iw * jcp.id
2875 : 1;
2876 size_t input_offset
2877 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
2878 safe_add(reg_input, input_offset, reg_long_offt);
2879 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
2880 sub(b_ic, ic_block_step);
2881 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2882 cmp(b_ic, ic_block_step);
2883 jge(ic_block_label, T_NEAR);
2884 }
2885 L(ic_block_label_end);
2886
2887 const int input_shift
2888 = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * inp_mult;
2889
2890 if (generate_icb_loop || ic_tail) {
2891 const size_t kernel_icb_loop_shift_bytes = (size_t)jcp.typesize_out
2892 * jcp.kd * jcp.kh * jcp.kw * ic_block * oc_block;
2893
2894 if (generate_icb_loop) {
2895 // icb loop supported for src in nxc layout only
2896 assert(src_layout_nxc);
2897 safe_add(reg_kernel,
2898 kernel_icb_loop_shift_bytes
2899 - jcp.typesize_out * ic_block * oc_block,
2900 reg_long_offt);
2901
2902 cmp(reg_icb, ic_block);
2903 jge(icb_block_label, T_NEAR);
2904 }
2905
2906 L(ic_tail_loop_label);
2907 if (ic_tail) {
2908 Label skip_ic_tail;
2909 const int ic_tail_loop_work = rnd_dn(ic_tail, ic_block_step);
2910 cmp(reg_icb, 0);
2911 jle(skip_ic_tail, T_NEAR);
2912 mov(b_ic, reg_icb);
2913 if (ic_tail_loop_work) {
2914 cmp(reg_icb, ic_block_step);
2915 jge(ic_block_label, T_NEAR);
2916 if (generate_icb_loop) {
2917 // compensate offset added in generate_icb_loop
2918 safe_sub(reg_kernel,
2919 kernel_icb_loop_shift_bytes
2920 - jcp.typesize_out * ic_block
2921 * oc_block,
2922 reg_long_offt);
2923 }
2924 }
2925
2926 L(ic_tail_label);
2927 if (ic_tail % ic_block_step) {
2928 cmp(reg_icb, 0);
2929 jle(skip_ic_tail, T_NEAR);
2930 ic_loop(ic_tail % ic_block_step);
2931 }
2932 L(skip_ic_tail);
2933 }
2934
2935 pop(reg_kernel);
2936 pop(reg_input);
2937
2938 add(reg_input, input_shift);
2939 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
2940 } else if (jcp.is_1stconv && !src_layout_nxc) {
2941 size_t input_offset = (size_t)jcp.typesize_in * jcp.id * jcp.ih
2942 * jcp.iw * ic_block;
2943 safe_sub(reg_input, input_offset, reg_long_offt);
2944 add(reg_input, input_shift);
2945 } else if (!jcp.is_hw_transp) {
2946 add(reg_input, input_shift - jcp.typesize_in * ic_block);
2947 }
2948 if (!jcp.is_hw_transp && !(generate_icb_loop || ic_tail))
2949 add(reg_kernel,
2950 jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
2951 dec(kj);
2952 cmp(kj, 0);
2953 jg(kh_label, T_NEAR);
2954 }
2955 if (jcp.ndims == 5) {
2956 add(aux_reg_input,
2957 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * jcp.iw
2958 * inp_mult);
2959 add(aux_reg_kernel,
2960 jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block);
2961 dec(ki);
2962 cmp(ki, 0);
2963 jg(kd_label, T_NEAR);
2964 }
2965}
2966
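// ic_block_step is the number of input channels handled per inner step; it
// is limited by the register budget needed for the kw * ic_block_step
// diff-weights accumulators, so wider filters use a smaller step. Dispatch
// to the fully unrolled variants for small kw/ow and to the generic loop
// otherwise.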
2967void jit_avx512_common_conv_bwd_weights_kernel_f32 ::compute_oh_step_disp() {
2968 int ic_block_step;
2969 if (jcp.kernel_kind == expl_bcast)
2970 ic_block_step = jcp.kw <= 3 ? 4 : (jcp.kw <= 7 ? 2 : 1);
2971 else
2972 ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
2973
2974 if (jcp.is_1stconv) {
2975 bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
2976 ic_block_step = (jcp.kw * jcp.ic_block <= 28 && !large_code)
2977 ? jcp.ic_block
2978 : 1;
2979 }
2980
2981 bool too_large_to_unroll = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
2982 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
2983
2984 int ow = jcp.is_hw_transp ? jcp.oh : jcp.ow;
2985 if (jcp.ndims == 5) {
2986 /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of
2987 * 'movs' must be guaranteed. */
2988 mov(ki, reg_kd_count);
2989 push(reg_kd_count);
2990 mov(aux_reg_input, reg_input);
2991 mov(aux_reg_kernel, reg_kernel);
2992 }
2993
2994 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
2995 compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
2996 else if (ow <= max_ur_w)
2997 compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
2998 else
2999 compute_oh_step_common(ic_block_step, max_ur_w);
3000
3001 if (jcp.ndims == 5) {
3002 mov(reg_input, aux_reg_input);
3003 mov(reg_kernel, aux_reg_kernel);
3004 pop(reg_kd_count);
3005 od_step_comeback_pointers();
3006 } else {
3007 oh_step_comeback_pointers();
3008 }
3009}
3010
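// Zero the diff-weights buffer when the 'channel' field of the call params
// is non-zero (used here as a request to initialize the current reduction
// chunk), covering the whole kd * kh * kw * ic_block * oc_block block and,
// when the icb loop is generated, every ic block counted by reduce_work.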
3011void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() {
3012 Label skip_zeroing, zeroing_loop;
3013
3014 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3015 cmp(reg_tmp, 0);
3016 jz(skip_zeroing, T_NEAR);
3017
3018 Zmm zero = Zmm(0);
3019 vpxord(zero, zero, zero);
3020 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3021 const size_t kernel_block_bytes = (size_t)jcp.ic_block * jcp.oc_block
3022 * jcp.kw * jcp.kh * jcp.kd * jcp.typesize_out;
3023 Label icb_block_label, icb_block_label_cb;
3024 if (generate_icb_loop) {
3025 push(reg_kernel);
3026
3027 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3028 L(icb_block_label);
3029 }
3030
3031 xor_(reg_tmp, reg_tmp);
3032 L(zeroing_loop);
3033 {
3034 assert(jcp.oc_block * jcp.typesize_out
3035 == cpu_isa_traits<avx512_core>::vlen);
3036 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3037 vmovups(ptr[reg_kernel + reg_tmp
3038 + ic1 * jcp.oc_block * jcp.typesize_out],
3039 zero);
3040 add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
3041 cmp(reg_tmp, kernel_block_bytes);
3042 jnz(zeroing_loop);
3043 }
3044 if (generate_icb_loop) {
3045 add(reg_kernel, kernel_block_bytes);
3046 sub(reg_icb, jcp.ic_block);
3047 cmp(reg_icb, 0);
3048 jg(icb_block_label, T_NEAR);
3049
3050 pop(reg_kernel);
3051 }
3052
3053 L(skip_zeroing);
3054}
3055
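// Accumulate diff_bias over one output row (ow pixels), unless the 'flags'
// field of the call params is non-zero, in which case the bias update is
// skipped for this call.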
3056void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_2d() {
3057 assert(jcp.ndims == 4); // only supports 2d
3058 Label skip_bias, bias_loop;
3059 const int oc_tail = jcp.oc_tail;
3060
3061 mov(reg_tmp, ptr[param1 + GET_OFF(flags)]);
3062 mov(reg_bias, ptr[param + GET_OFF(bias)]);
3063 test(reg_tmp, reg_tmp);
3064 jnz(skip_bias, T_NEAR);
3065
3066 vmovups(Zmm(0), ptr[reg_bias]);
3067
3068 mov(reg_oi, jcp.ow);
3069 xor_(reg_tmp, reg_tmp);
3070 L(bias_loop);
3071 {
3072 auto zmm_out = Zmm(1);
3073 if (oc_tail) zmm_out = zmm_out | k_oc_mask | T_z;
3074 vmovups(zmm_out, ptr[reg_output + reg_tmp]);
3075 vaddps(Zmm(0), Zmm(0), Zmm(1));
3076 const int oc_stride
3077 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
3078 add(reg_tmp, jcp.typesize_out * oc_stride);
3079 dec(reg_oi);
3080 jg(bias_loop);
3081 }
3082 vmovups(ptr[reg_bias], Zmm(0));
3083
3084 L(skip_bias);
3085}
3086
3087void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_3d() {
3088 assert(jcp.ndims == 5); // only supports 3d
3089 Label skip_bias, bias_loop, skip_load_bias;
3090 const bool oc_tail = jcp.oc_tail;
3091
3092 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3093 test(reg_tmp, reg_tmp);
3094 jne(skip_bias, T_NEAR);
3095
3096 mov(reg_bias, ptr[param + GET_OFF(bias)]);
3097 mov(reg_output, ptr[param + GET_OFF(dst)]);
3098 vpxord(Zmm(1), Zmm(1), Zmm(1));
3099
3100 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3101 cmp(reg_tmp, 0);
3102 jne(skip_load_bias, T_NEAR);
3103 vmovups(Zmm(1), ptr[reg_bias]);
3104
3105 L(skip_load_bias);
3106
3107 mov(reg_oi, ptr[param + GET_OFF(os_index_end)]);
3108 sub(reg_oi, ptr[param + GET_OFF(os_index_begin)]);
3109 cmp(reg_oi, 0);
3110 jle(skip_bias, T_NEAR); // no iterations along depth dimension
3111
3112 const size_t oc_mult
3113 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
3114 mov(reg_tmp, oc_mult * jcp.ow * jcp.oh * jcp.typesize_out);
3115 imul(reg_oi, reg_tmp);
3116
3117 xor_(reg_tmp, reg_tmp);
3118 L(bias_loop);
3119 {
3120 auto zmm_out = Zmm(0);
3121 if (oc_tail) zmm_out = zmm_out | k_oc_mask | T_z;
3122 vmovups(zmm_out, ptr[reg_output + reg_tmp]);
3123 vaddps(Zmm(1), Zmm(1), Zmm(0));
3124 add(reg_tmp, oc_mult * jcp.typesize_out);
3125 cmp(reg_tmp, reg_oi);
3126 jl(bias_loop);
3127 }
3128 vmovups(ptr[reg_bias], Zmm(1));
3129
3130 L(skip_bias);
3131}
3132
3133void jit_avx512_common_conv_bwd_weights_kernel_f32 ::compute_oh_loop_common() {
3134 assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));
3135 int b_pad = jcp.b_pad;
3136 int t_pad = jcp.t_pad;
3137 bool is_dilated = jcp.dilate_h != 0;
3138 int dilate_h = jcp.dilate_h + 1;
3139 int stride_h = jcp.stride_h;
3140 const int inp_mult = is_src_layout_nxc()
3141 ? jcp.ngroups * jcp.ic
3142 : (jcp.is_1stconv ? 1 : jcp.ic_block);
3143 const int out_mult
3144 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
3145 int iw = jcp.is_hw_transp ? 1 : jcp.iw;
3146 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
3147 oh_bpad_label, oh_bpad_label_end, oh_dilate_label_shift,
3148 oh_dilate_label_noshift, oh_dilate_label_end;
3149
3150 int ow = jcp.is_hw_transp ? jcp.oh : jcp.ow;
3151 int oh = jcp.is_hw_transp ? jcp.ow : jcp.oh;
3152 int kw = jcp.is_hw_transp ? jcp.tr_kw : jcp.kw;
3153 int kh = jcp.is_hw_transp ? jcp.tr_kh : jcp.kh;
3154 int ih = jcp.is_hw_transp ? jcp.tr_ih : jcp.ih;
3155 int ihp = jcp.is_hw_transp ? jcp.tr_ih : jcp.ihp;
3156
3157 assert(IMPLICATION(jcp.is_hw_transp,
3158 everyone_is(1, oh, stride_h, dilate_h)
3159 && everyone_is(0, b_pad, t_pad)));
3160
3161 mov(reg_kh, kh);
3162 xor_(reg_oj, reg_oj);
3163 /* Compute 'top' edge */
3164 if (t_pad > 0) {
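        // For the first output row, 'overflow' counts kernel rows that fall
        // below the input and 'underflow' counts rows still inside the top
        // padding, so initial_inp_ker_overlap is the number of kernel rows
        // that actually touch the input.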
3165 const int kh_range = 1 + (kh - 1) * dilate_h;
3166 const int overflow = nstl::max(0, kh - div_up(t_pad + ih, dilate_h));
3167 const int underflow = div_up(t_pad, dilate_h);
3168 const int initial_inp_ker_overlap = kh - overflow - underflow;
3169 mov(reg_kh, initial_inp_ker_overlap);
3170 add(reg_kernel,
3171 jcp.typesize_out * underflow * kw * jcp.ic_block
3172 * jcp.oc_block);
3173 // generate loop to process kernel while it remains within t_pad + ih
3174 if (kh_range < t_pad + ih) {
3175 if (is_dilated) {
3176 const int tail = t_pad % dilate_h;
3177 const int shift = tail == 0 ? 0 : dilate_h - tail;
3178 mov(reg_tmp, shift);
3179 if (tail != 0)
3180 add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
3181 }
3182 L(oh_tpad_label);
3183 {
3184 cmp(reg_oj, oh);
3185 jge(oh_label_end, T_NEAR);
3186
3187 compute_oh_step_disp();
3188 add(reg_output, jcp.typesize_in * ow * out_mult);
3189 if (is_dilated) {
3190 inc(reg_tmp);
3191 cmp(reg_tmp, dilate_h);
3192 jl(oh_dilate_label_shift, T_NEAR);
3193 // unshift input as new kernel element enters
3194 sub(reg_input,
3195 jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
3196 xor_(reg_tmp, reg_tmp);
3197 }
3198 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3199 sub(reg_kernel,
3200 jcp.typesize_out * stride_h * kw * jcp.ic_block
3201 * jcp.oc_block);
3202 add(reg_kh, stride_h);
3203 if (is_dilated) {
3204 jmp(oh_dilate_label_noshift, T_NEAR);
3205 L(oh_dilate_label_shift);
3206 // shift input as old kernel element progresses
3207 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3208 L(oh_dilate_label_noshift);
3209 }
3210 inc(reg_oj);
3211
3212 // final number of kernel elements that overlap with input
3213 const int final_inp_ker_overlap
3214 = nstl::min(kh, div_up(ih, dilate_h));
3215 cmp(reg_kh, final_inp_ker_overlap);
3216 jl(oh_tpad_label, T_NEAR);
3217 }
3218 }
3219 // need second loop to process kernel if it is larger than the input
3220 // (does not apply to dilations as they must have unit stride)
3221 if (kh_range
3222 >= ih + (t_pad % stride_h == 0 ? stride_h : t_pad % stride_h)) {
3223 assert(!is_dilated);
3224 mov(reg_kh, ih);
3225 L(oh_tpad_tail_label);
3226 {
3227 cmp(reg_oj, oh);
3228 jge(oh_label_end, T_NEAR);
3229
3230 compute_oh_step_disp();
3231 add(reg_output, jcp.typesize_in * ow * out_mult);
3232 sub(reg_kernel,
3233 jcp.typesize_out * stride_h * kw * jcp.ic_block
3234 * jcp.oc_block);
3235
3236 inc(reg_oj);
3237 cmp(reg_oj, nstl::min(utils::div_up(t_pad, stride_h), oh));
3238 jl(oh_tpad_tail_label, T_NEAR);
3239 }
3240 }
3241 // correct any excess shifts to kernel and input
3242 // (does not apply to dilations as they must have unit stride,
3243 // kernel must fit inside input, and padding is smaller than input)
3244 if (t_pad <= oh * stride_h) {
3245 // kernel has moved beyond padding (adjust for stride effects)
3246 if (t_pad % stride_h != 0) {
3247 assert(!is_dilated);
3248 int inp_corr = stride_h - t_pad % stride_h;
3249 add(reg_kernel,
3250 jcp.typesize_out * inp_corr * kw * jcp.ic_block
3251 * jcp.oc_block);
3252 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
3253 }
3254 } else {
3255 // kernel still overlaps padding (complete reset)
3256 assert(!is_dilated);
3257 sub(reg_kernel,
3258 jcp.typesize_out * (t_pad - oh * stride_h) * kw
3259 * jcp.ic_block * jcp.oc_block);
3260 }
3261 }
3262
3263 const int oj_end_value = nstl::min(
3264 oh, utils::div_up(ihp - b_pad - (kh - 1) * dilate_h, stride_h));
3265 cmp(reg_oj, oj_end_value);
3266 jge(oh_label_end, T_NEAR);
3267
3268 /* Compute middle block(s) */
3269 mov(reg_kh, kh);
3270 L(oh_label);
3271 {
3272 compute_oh_step_disp();
3273 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3274 add(reg_output, jcp.typesize_in * ow * out_mult);
3275
3276 inc(reg_oj);
3277 cmp(reg_oj, oj_end_value);
3278 jl(oh_label, T_NEAR);
3279 }
3280 L(oh_label_end);
3281
3282 /* Compute bottom edge */
3283 if (b_pad > 0) {
3284 cmp(reg_oj, oh);
3285 jge(oh_bpad_label_end, T_NEAR);
3286
3287 if (is_dilated) {
3288 mov(reg_kh, kh - 1); // assumes unit stride for dilations
3289 mov(reg_tmp, 0);
3290 } else {
3291 mov(reg_kh, ihp - b_pad);
3292 imul(reg_tmp, reg_oj, stride_h);
3293 sub(reg_kh, reg_tmp);
3294 }
3295 L(oh_bpad_label);
3296 {
3297 compute_oh_step_disp();
3298 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3299 add(reg_output, jcp.typesize_in * ow * out_mult);
3300 if (is_dilated) {
3301 inc(reg_tmp);
3302 cmp(reg_tmp, dilate_h);
3303 jl(oh_dilate_label_end, T_NEAR);
3304 xor_(reg_tmp, reg_tmp);
3305 }
3306 sub(reg_kh, stride_h);
3307 cmp(reg_kh, 0);
3308 jle(oh_bpad_label_end, T_NEAR);
3309 if (is_dilated) L(oh_dilate_label_end);
3310
3311 inc(reg_oj);
3312 cmp(reg_oj, oh);
3313 jl(oh_bpad_label, T_NEAR);
3314 }
3315 L(oh_bpad_label_end);
3316 }
3317}
3318
3319void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_oh_loop_partial() {
3320 assert(jcp.harness == harness_2d_reduction);
3321 int ic_block = jcp.ic_block;
3322 int oc_block = jcp.oc_block;
3323 const int inp_mult = is_src_layout_nxc()
3324 ? jcp.ngroups * jcp.ic
3325 : (jcp.is_1stconv ? 1 : jcp.ic_block);
3326 const int out_mult
3327 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
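    // First output row whose kernel window reaches the bottom padding, and
    // the number of input rows that still overlap the kernel at that row.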
3328 const int input_bottom_padding_overlap
3329 = div_up(jcp.ih + jcp.t_pad - (jcp.kh - 1), jcp.stride_h);
3330 const int bottom_pad_input_correction
3331 = jcp.ih + jcp.t_pad - input_bottom_padding_overlap * jcp.stride_h;
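    // Illustrative numbers only: with ih = 10, t_pad = 1, kh = 3, stride_h = 2,
    // the first output row whose filter taps reach the bottom padding is
    // div_up(10 + 1 - 2, 2) = 5, and at that row only 10 + 1 - 5 * 2 = 1
    // filter row still overlaps real input.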
3332
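    // The shifts below are in bytes: filter_shift covers one kh step of the
    // blocked (kw, ic_block, oc_block) weight tile, input_shift one src row,
    // and output_shift one diff_dst row.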
3333 const size_t filter_shift = jcp.typesize_out * jcp.kw * ic_block * oc_block;
3334 const size_t input_shift = jcp.typesize_in * jcp.iw * inp_mult;
3335 const size_t output_shift = jcp.typesize_out * jcp.ow * out_mult;
3336
3337 Label loop_begin_label, loop_end_label, common_block_label,
3338 top_padding_end_label, bottom_padding_end_label,
3339 bottom_padding_label;
3340
3341 if (jcp.with_bias) {
3342 Label skip_zero_bias;
3343 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
3344 mov(reg_tmp, ptr[param1 + GET_OFF(channel)]);
3345 test(reg_tmp, reg_tmp);
3346 jz(skip_zero_bias, T_NEAR);
3347 mov(reg_tmp, ptr[param1 + GET_OFF(flags)]);
3348 test(reg_tmp, reg_tmp);
3349 jnz(skip_zero_bias, T_NEAR);
3350 vpxord(Zmm(1), Zmm(1), Zmm(1));
3351 vmovups(ptr[reg_bias], Zmm(1));
3352 L(skip_zero_bias);
3353 }
3354
3355 /* Offset filter position to adjust for top padding */
3356 add(reg_kernel, ptr[param + GET_OFF(kh_offset)]);
3357
3358 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
3359 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
3360
3361 cmp(reg_kh, 0);
3362 jle(loop_end_label, T_NEAR); // no iterations along kh
3363 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3364 jge(loop_end_label, T_NEAR); // no iterations along height dimension
3365
3366 L(loop_begin_label);
3367
3368 if (jcp.with_bias) bias_kernel_2d();
3369 compute_oh_step_disp();
3370
3371 /* Compute 'top' edge */
3372 if (jcp.t_pad > 0) {
3373
3374 /* Check if within top padding region */
3375 cmp(reg_oj, div_up(jcp.t_pad, jcp.stride_h));
3376 jge(top_padding_end_label, T_NEAR);
3377
3378 /* Increment step counter and adjust filter position */
3379 sub(reg_kernel, filter_shift * jcp.stride_h);
3380 add(reg_kh, jcp.stride_h);
3381
3382 /* Final number of kernel elements that overlap with input */
3383 const int inp_ker_overlap = nstl::min(jcp.kh, jcp.ih);
3384 cmp(reg_kh, inp_ker_overlap);
3385 jle(common_block_label, T_NEAR);
3386
3387 /* Correct any excess shifts to kernel and input */
3388 if (jcp.t_pad <= jcp.oh * jcp.stride_h) {
3389 /* Filter has moved beyond padding (adjust for stride effects) */
3390 if (jcp.t_pad % jcp.stride_h != 0) {
3391 int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
3392 add(reg_kernel, filter_shift * inp_corr);
3393 add(reg_input, input_shift * inp_corr);
3394 }
3395 } else {
3396 /* Filter still overlaps padding (complete reset) */
3397 sub(reg_kernel, (jcp.t_pad - jcp.oh * jcp.stride_h) * filter_shift);
3398 }
3399
3400 /* Set filter element count for outside the t_pad region */
3401 if (jcp.t_pad + jcp.ih < jcp.kh + jcp.stride_h) {
3402 // filter now overlaps with b_pad
3403 mov(reg_kh, bottom_pad_input_correction);
3404 } else {
3405 mov(reg_kh, inp_ker_overlap);
3406 }
3407
3408 jmp(common_block_label);
3409
3410 L(top_padding_end_label);
3411 }
3412
3413 /* Compute 'bottom' edge */
3414 if (jcp.b_pad > 0) {
3415
3416 /* Check if within bottom padding region */
3417 cmp(reg_oj, input_bottom_padding_overlap - 1);
3418 jl(bottom_padding_end_label, T_NEAR);
3419 jg(bottom_padding_label, T_NEAR);
3420
3421 /* Execute overlap correction between the filter and the initial
3422 * bottom padding region. */
3423 mov(reg_kh, bottom_pad_input_correction);
3424 jmp(bottom_padding_end_label, T_NEAR);
3425
3426 L(bottom_padding_label);
3427 sub(reg_kh, jcp.stride_h);
3428 cmp(reg_kh, 0);
3429 jle(loop_end_label, T_NEAR);
3430
3431 L(bottom_padding_end_label);
3432 }
3433
3434 /* Compute middle block */
3435 add(reg_input, input_shift * jcp.stride_h);
3436
3437 /* Execute common block and loop */
3438 L(common_block_label);
3439 add(reg_output, output_shift);
3440 inc(reg_oj);
3441 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3442 jl(loop_begin_label, T_NEAR);
3443
3444 L(loop_end_label);
3445}
3446
3447void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_od_loop_partial() {
3448 assert(jcp.harness == harness_3d_reduction);
3449 int ic_block = jcp.ic_block;
3450 int oc_block = jcp.oc_block;
3451 const int inp_mult = is_src_layout_nxc()
3452 ? jcp.ngroups * jcp.ic
3453 : (jcp.is_1stconv ? 1 : jcp.ic_block);
3454 const int out_mult
3455 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
3456
3457 int iw = jcp.iw;
3458 int ow = jcp.ow;
3459 const int input_backpad_overlap
3460 = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
3461 const int back_pad_input_correction
3462 = jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d;
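    // Depth analogue of the bottom-padding bookkeeping in
    // compute_oh_loop_partial(): the first od index whose filter taps reach
    // the back padding, and how many kd rows still overlap real input there.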
3463
3464 const size_t filter_shift
3465 = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block;
3466 const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult;
3467 const size_t output_shift = jcp.typesize_in * jcp.oh * ow * out_mult;
3468
3469 Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
3470 backpad_end_label, backpad_label;
3471
3472 if (jcp.with_bias) bias_kernel_3d();
3473
3474 /* initially offset 'kd' by f_pad */
3475 add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3476
3477 mov(reg_input_d, ptr[param + GET_OFF(src)]);
3478 mov(reg_output_d, ptr[param + GET_OFF(dst)]);
3479 mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
3480 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3481
3482 cmp(reg_kd_count, 0);
3483 jle(loop_end_label, T_NEAR); // no iterations along kd
3484 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3485 jge(loop_end_label, T_NEAR); // no iterations along depth dimension
3486
3487 L(d_loop_label);
3488
3489 mov(reg_input, reg_input_d);
3490 mov(reg_output, reg_output_d);
3491
3492 push(reg_input_d);
3493 push(reg_output_d);
3494 push(reg_d_index);
3495
3496 compute_oh_loop_common();
3497
3498 pop(reg_d_index);
3499 pop(reg_output_d);
3500 pop(reg_input_d);
3501
3502 /* Compute 'front' edge */
3503 if (jcp.f_pad > 0) {
3504
3505 /* Check if within fpad region */
3506 cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
3507 jge(fpad_end_label, T_NEAR);
3508
        /* Increment step counter and adjust filter position */
3510 sub(reg_kernel, filter_shift * jcp.stride_d);
3511 add(reg_kd_count, jcp.stride_d);
3512
3513 /* Final number of kernel elements that overlap with input */
3514 const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id);
3515 cmp(reg_kd_count, inp_ker_overlap);
3516 jle(common_block_label, T_NEAR);
3517
3518 /* Correct any excess shifts to kernel and input */
3519 if (jcp.f_pad <= jcp.od * jcp.stride_d) {
3520 /* Filter has moved beyond padding (adjust for stride effects) */
3521 if (jcp.f_pad % jcp.stride_d != 0) {
3522 int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
3523 add(reg_kernel, filter_shift * inp_corr);
3524 add(reg_input_d, input_shift * inp_corr);
3525 }
3526 } else {
3527 /* Filter still overlaps padding (complete reset) */
3528 sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
3529 }
3530
3531 /* Set filter element count for outside the f_pad region */
3532 if (jcp.f_pad + jcp.id < jcp.kd + jcp.stride_d) {
3533 // filter now overlaps with back_pad
3534 mov(reg_kd_count, back_pad_input_correction);
3535 } else {
3536 mov(reg_kd_count, inp_ker_overlap);
3537 }
3538
3539 jmp(common_block_label);
3540
3541 L(fpad_end_label);
3542 }
3543
    /* Compute 'back' edge */
3545 if (jcp.back_pad > 0) {
3546
3547 /* Check if within back_pad region */
3548 cmp(reg_d_index, input_backpad_overlap - 1);
3549 jl(backpad_end_label, T_NEAR);
3550 jg(backpad_label, T_NEAR);
3551
3552 /* Execute overlap correction between the filter and the initial
3553 * back_pad region. */
3554 mov(reg_kd_count, back_pad_input_correction);
3555 jmp(backpad_end_label, T_NEAR);
3556
3557 L(backpad_label);
3558 sub(reg_kd_count, jcp.stride_d);
3559 cmp(reg_kd_count, 0);
3560 jle(loop_end_label, T_NEAR);
3561
3562 L(backpad_end_label);
3563 }
3564
3565 /* Compute middle block */
3566 add(reg_input_d, input_shift * jcp.stride_d);
3567
3568 /* Execute common block and loop */
3569 L(common_block_label);
3570 add(reg_output_d, output_shift);
3571 inc(reg_d_index);
3572 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3573 jl(d_loop_label, T_NEAR);
3574
3575 L(loop_end_label);
3576}
3577
3578void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() {
3579
3580 maybe_zero_kernel();
3581
3582 switch (jcp.harness) {
3583 case harness_2d_reduction: compute_oh_loop_partial(); break;
3584 case harness_3d_reduction: compute_od_loop_partial(); break;
3585 case harness_mb_reduction: compute_oh_loop_common(); break;
3586 case harness_nxc: break;
3587 default: assert(!"Invalid harness type");
3588 }
3589}
3590
3591void jit_avx512_common_conv_bwd_weights_kernel_f32::generate_microkernel() {
3592
3593 reg64_t reg_dwei = abi_param1;
3594 reg64_t reg_src = abi_param2;
3595 reg64_t reg_ddst = abi_param3;
3596 reg64_t reg_iw_base = abi_param4;
3597 reg64_t aux_reg_icb = r10;
3598 reg64_t aux_reg_kwb = r11;
3599 reg64_t reg_src_save = r12;
3600 reg64_t reg_dwei_save = r13;
3601 reg64_t reg_iw_base_save = r14;
3602 reg64_t reg_tmp = r15;
3603
    // Currently the kernel is small, so passing parameters via registers is
    // preferred whenever possible.
3606#ifdef _WIN32
3607 // Must be a scratch register since load is before preamble
3608 reg64_t reg_owb = rax;
3609 mov(reg_owb, ptr[get_stack_params_address(false)]);
3610#else
3611 reg64_t reg_owb = abi_param5;
3612#endif
3613
3614 preamble();
3615
3616 const int kw_unroll = jcp.ur_kw;
3617 const int ow_unroll = jcp.ur_ow;
3618 const int iw_unroll = ow_unroll + kw_unroll - 1;
3619 const int ic_unroll = jcp.ur_ic;
3620
3621 const int ker_reg_count = ic_unroll;
3622 const int src_reg_count = iw_unroll * ic_unroll;
3623 const int ddst_reg_count = ow_unroll;
3624
3625 MAYBE_UNUSED(ddst_reg_count);
3626 assert(ker_reg_count + src_reg_count + ddst_reg_count <= 32);
3627
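    // Byte offsets below assume the blocked diff_weights layout used by this
    // kernel (the *16i16o family of tags): from outer to inner the dimensions
    // are icb, kd, kh, kw, ic within the block, and oc within the block, with
    // oc being the contiguous vector dimension.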
3628 auto dwei_offset = [&](int i_kw, int i_ic) {
3629 const int oc_block_size = sizeof(float);
3630 const int ic_block_size = jcp.oc_block * oc_block_size;
3631 const int kw_block_size = jcp.ic_block * ic_block_size;
3632 const int kh_block_size = jcp.kw * kw_block_size;
3633 const int kd_block_size = jcp.kh * kh_block_size;
3634 const int icb_block_size = jcp.kd * kd_block_size;
3635
3636 int icb = i_ic / jcp.ic_block;
3637 i_ic = i_ic % jcp.ic_block;
3638
3639 return icb * icb_block_size + i_kw * kw_block_size
3640 + i_ic * ic_block_size;
3641 };
3642
3643 auto src_offset = [&](int i_ic, int i_iw) {
3644 const int ic_block_size = sizeof(float);
3645 const int g_block_size = jcp.ic * ic_block_size;
3646 const int iw_block_size = jcp.ngroups * g_block_size;
3647
3648 return i_iw * iw_block_size + i_ic * ic_block_size;
3649 };
3650
3651 auto ddst_offset = [&](int i_ow) {
3652 const int oc_block_size = sizeof(float);
3653 const int g_block_size = jcp.oc * oc_block_size;
3654 const int ow_block_size = jcp.ngroups * g_block_size;
3655
3656 return i_ow * ow_block_size;
3657 };
3658
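    // zmm register budget: [0, ic_unroll) hold the per-ic weight accumulators,
    // the next iw_unroll * ic_unroll registers hold broadcast src values, and
    // the top ow_unroll registers hold diff_dst vectors, one per unrolled ow
    // position (see the assert above).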
3659 auto get_src_zmm = [=](int iw_index, int i_ic) {
3660 int zmm_index = iw_index * ic_unroll + i_ic + ker_reg_count;
3661 return Zmm(zmm_index);
3662 };
3663
3664 auto get_ddst_zmm = [=](int i_ow) {
3665 int zmm_index = i_ow + src_reg_count + ker_reg_count;
3666 return Zmm(zmm_index);
3667 };
3668
3669 auto get_ker_zmm = [=](int i_ic) { return Zmm(i_ic); };
3670
3671 auto load_ddsts = [=](int ur_ow) {
3672 for (int i_ow = 0; i_ow < ur_ow; i_ow++) {
3673 vmovups(get_ddst_zmm(i_ow), zword[reg_ddst + ddst_offset(i_ow)]);
3674 }
3675 };
3676
3677 auto load_srcs = [=](int ur_iw, int ur_ic, bool is_iw_edge) {
3678 Label iw_load_end;
3679 if (is_iw_edge) {
3680 for_(int i_iw_index = 0; i_iw_index < ur_iw; i_iw_index++)
3681 for (int i_ic = 0; i_ic < ur_ic; i_ic++) {
3682 vpxord(get_src_zmm(i_iw_index, i_ic),
3683 get_src_zmm(i_iw_index, i_ic),
3684 get_src_zmm(i_iw_index, i_ic));
3685 }
3686 }
3687
3688 for (int i_iw_index = 0; i_iw_index < ur_iw; i_iw_index++) {
3689 Label ic_load_end;
3690 if (is_iw_edge) {
3691 cmp(reg_iw_base, jcp.iw - i_iw_index * jcp.stride_w);
3692 jge(iw_load_end, T_NEAR);
3693 if (jcp.l_pad > 0) {
3694 cmp(reg_iw_base, -i_iw_index * jcp.stride_w);
3695 jl(ic_load_end, T_NEAR);
3696 }
3697 }
3698 for (int i_ic = 0; i_ic < ur_ic; i_ic++) {
3699 vbroadcastss(get_src_zmm(i_iw_index, i_ic),
3700 zword[reg_src
3701 + src_offset(i_ic, jcp.stride_w * i_iw_index)]);
3702 }
3703 L(ic_load_end);
3704 }
3705 L(iw_load_end);
3706 };
3707
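    // For each kw tap, compute_kernel accumulates the sum over the unrolled ow
    // of src (broadcast scalar) * diff_dst (vector) into a per-ic accumulator,
    // then read-modify-writes the result into the diff_weights buffer.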
3708 auto compute_kernel = [=](int ur_ow, int ur_ic, int ur_kw, int is_iw_edge) {
3709 Label kw_loop_end;
3710 load_srcs(ur_ow + ur_kw - 1, ur_ic, is_iw_edge);
3711
3712 for (int i_kw = 0; i_kw < ur_kw; i_kw++) {
3713 for (int i_ic = 0; i_ic < ur_ic; i_ic++) {
3714 vpxord(get_ker_zmm(i_ic), get_ker_zmm(i_ic), get_ker_zmm(i_ic));
3715 }
3716 for (int i_ow = 0; i_ow < ur_ow; i_ow++) {
3717 for (int i_ic = 0; i_ic < ur_ic; i_ic++) {
3718 vfmadd231ps(get_ker_zmm(i_ic),
3719 get_src_zmm(i_ow + i_kw, i_ic), get_ddst_zmm(i_ow));
3720 }
3721 }
3722 for (int i_ic = 0; i_ic < ur_ic; i_ic++) {
3723 int ker_offset = dwei_offset(i_kw, i_ic);
3724 vaddps(get_ker_zmm(i_ic), zword[reg_dwei + ker_offset]);
3725 vmovups(zword[reg_dwei + ker_offset], get_ker_zmm(i_ic));
3726 }
3727 }
3728
3729 L(kw_loop_end);
3730 };
3731
3732 auto kw_loop = [=](int ur_ow, int ur_ic, int is_iw_edge) {
3733 Label kwb_loop_begin, kwb_loop_end;
3734 int kw_tail = jcp.kw % kw_unroll;
3735 int kw_iter = jcp.kw / kw_unroll;
3736
3737 if (kw_iter > 0) {
3738 if (kw_iter > 1) {
3739 mov(aux_reg_kwb, jcp.kw - kw_tail);
3740 L(kwb_loop_begin);
3741 }
3742 compute_kernel(ur_ow, ur_ic, kw_unroll, is_iw_edge);
3743
3744 if (kw_iter > 1 || kw_tail) {
3745 add(reg_iw_base, (jcp.dilate_w + 1) * kw_unroll);
3746 add(reg_src, src_offset(0, (jcp.dilate_w + 1) * kw_unroll));
3747 add(reg_dwei, dwei_offset(kw_unroll, 0));
3748 }
3749
3750 if (kw_iter > 1) {
3751 sub(aux_reg_kwb, kw_unroll);
3752 jg(kwb_loop_begin, T_NEAR);
3753 }
3754 }
3755
3756 if (kw_tail) compute_kernel(ur_ow, ur_ic, kw_tail, is_iw_edge);
3757
3758 L(kwb_loop_end);
3759 };
3760
3761 auto ic_loop = [=](int ur_ow, int is_iw_edge) {
3762 Label icb_loop_begin, icb_loop_end;
3763 int ic_tail = jcp.ic % ic_unroll;
3764 int ic_iter = jcp.ic / ic_unroll;
3765
3766 if (ic_iter > 0) {
3767 if (ic_iter > 1 || ic_tail) {
3768 mov(aux_reg_icb, jcp.ic - ic_tail);
3769 L(icb_loop_begin);
                // Saving onto the stack here appears to slow down code
                // execution significantly. If this kernel runs out of
                // registers, getting rid of the *_save registers should be
                // possible by using subtraction to restore their values while
                // maintaining performance.
3774 mov(reg_src_save, reg_src);
3775 mov(reg_dwei_save, reg_dwei);
3776 mov(reg_iw_base_save, reg_iw_base);
3777 }
3778
3779 kw_loop(ur_ow, ic_unroll, is_iw_edge);
3780
3781 if (ic_iter > 1 || ic_tail) {
3782 mov(reg_iw_base, reg_iw_base_save);
3783 mov(reg_dwei, reg_dwei_save);
3784 mov(reg_src, reg_src_save);
3785
3786 Label inter_block_increment, increment_finish;
3787 sub(aux_reg_icb, ic_unroll);
3788 if (jcp.ic > jcp.ic_block) {
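                    // Roughly: once the number of input channels processed so
                    // far is a multiple of ic_block, the next ic group starts
                    // a new icb block in the blocked diff_weights layout, so
                    // jump to the beginning of that block; otherwise advance
                    // by ic_unroll entries within the current block.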
3789 const int log2_ic_block = 4;
3790 lea(reg_tmp, ptr[aux_reg_icb - jcp.ic - ic_tail]);
3791 test(reg_tmp, (1 << log2_ic_block) - 1);
3792 jnz(inter_block_increment, T_NEAR);
3793
3794 add(reg_dwei,
3795 dwei_offset(0, jcp.ic_block)
3796 - dwei_offset(0, jcp.ic_block - ic_unroll));
3797 jmp(increment_finish);
3798 L(inter_block_increment);
3799 }
3800 add(reg_dwei, dwei_offset(0, ic_unroll));
3801 L(increment_finish);
3802
3803 add(reg_src, src_offset(ic_unroll, 0));
3804 }
3805 if (ic_iter > 1) {
3806 cmp(aux_reg_icb, 0);
3807 jg(icb_loop_begin, T_NEAR);
3808 }
3809 }
3810
3811 if (ic_tail) kw_loop(ur_ow, ic_tail, is_iw_edge);
3812
3813 L(icb_loop_end);
3814 };
3815
3816 auto ic_loop_dispatch = [=](int ur_ow) {
3817 Label iw_edge_case, ic_end;
3818
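        // The rightmost src element this unroll can touch sits at
        // iw_base + (ur_ow - 1) * stride_w + (kw - 1) * (dilate_w + 1); take
        // the edge path (with per-load bounds checks) when that index can
        // reach iw, or when iw_base is still inside the left padding.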
3819 const int iw_overflow_bound = jcp.iw - (ur_ow - 1) * jcp.stride_w
3820 - (jcp.kw - 1) * (jcp.dilate_w + 1);
3821 cmp(reg_iw_base, iw_overflow_bound);
3822 jge(iw_edge_case, T_NEAR);
3823 if (jcp.l_pad > 0) {
3824 cmp(reg_iw_base, 0);
3825 jl(iw_edge_case, T_NEAR);
3826 }
3827
3828 ic_loop(ur_ow, false);
3829 jmp(ic_end, T_NEAR);
3830
3831 L(iw_edge_case);
3832 ic_loop(ur_ow, true);
3833
3834 L(ic_end);
3835 };
3836
3837 Label ow_end, ow_tail;
3838 int ow_tail_size = jcp.ow % ow_unroll;
3839 cmp(reg_owb, jcp.ow - ow_tail_size);
3840 jge(ow_tail, T_NEAR);
3841
3842 load_ddsts(ow_unroll);
3843 ic_loop_dispatch(ow_unroll);
3844 jmp(ow_end, T_NEAR);
3845
3846 L(ow_tail);
3847 load_ddsts(ow_tail_size);
3848 ic_loop_dispatch(ow_tail_size);
3849
3850 L(ow_end);
3851
3852 postamble();
3853 ret();
3854}
3855
3856void jit_avx512_common_conv_bwd_weights_kernel_f32::generate_kernel() {
3857 preamble();
3858
3859 mov(reg_input, ptr[param + GET_OFF(src)]);
3860 mov(reg_output, ptr[param + GET_OFF(dst)]);
3861 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
3862
3863 const int oc_tail = jcp.oc_tail;
3864 if (oc_tail) {
3865 Label skip;
3866 Reg32 reg_tail_32 = reg_oc_tail.cvt32();
3867 if (jcp.nb_oc > 1) {
3868 kxnorw(k_oc_mask, k_oc_mask, k_oc_mask);
3869 mov(reg_oc_tail, ptr[param + GET_OFF(load_work)]);
3870 cmp(reg_oc_tail, 16);
3871 je(skip, T_NEAR);
3872 }
3873 mov(reg_tail_32, (1 << oc_tail) - 1);
3874 kmovw(k_oc_mask, reg_tail_32);
3875 L(skip);
3876 }
3877 compute_loop();
3878
3879 postamble();
3880}
3881
3882status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
3883 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
3884 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
3885 memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
3886 if (!mayiuse(avx512_core)) return status::unimplemented;
3887
3888 const memory_desc_wrapper src_d(&src_md);
3889 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
3890 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
3891 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
3892
3893 if (!utils::everyone_is(data_type::f32, src_d.data_type(),
3894 diff_weights_d.data_type(), diff_dst_d.data_type()))
3895 return status::unimplemented;
3896
3897 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
3898 int ndims = src_d.ndims();
3899
3900 jcp = zero<decltype(jcp)>();
3901
3902 jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / typesize;
3903 jcp.nthr = jcp.aligned_threads = nthreads;
3904 jcp.ndims = ndims;
3905 jcp.prop_kind = cd.prop_kind;
3906
3907 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
3908 jcp.mb = src_d.dims()[0];
3909
3910 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
3911 jcp.oc_without_padding = jcp.oc;
3912 jcp.ic = src_d.dims()[1] / jcp.ngroups;
3913 jcp.ic_without_padding = jcp.ic;
3914
3915 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
3916 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
3917 jcp.iw = src_d.dims()[ndims - 1];
3918 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
3919 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
3920 jcp.ow = diff_dst_d.dims()[ndims - 1];
3921
3922 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
3923 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
3924 jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
3925
3926 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
3927 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
3928 jcp.l_pad = cd.padding[0][ndims - 3];
3929
3930 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
3931 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
3932 jcp.stride_w = cd.strides[ndims - 3];
3933
3934 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
3935 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
3936 jcp.dilate_w = cd.dilates[ndims - 3];
3937
3938 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
3939 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
3940 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
3941
3942 bool ok = true
3943 // general condition to simplify dilations
3944 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
3945 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
3946 // special condition to simplify dilations in compute_oh_loop_common
3947 && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
3948 if (!ok) return status::unimplemented;
3949
3950 jcp.r_pad = nstl::max(0,
3951 calculate_end_padding(
3952 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
3953 jcp.b_pad = nstl::max(0,
3954 calculate_end_padding(
3955 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
3956 jcp.back_pad = nstl::max(0,
3957 calculate_end_padding(
3958 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
3959
3960 /* XXX: currently, does not support dilation_d > 0 */
3961 if (ndims == 5 && jcp.dilate_d > 0) return status::unimplemented;
3962
3963 /* Set bounds for large filter 'kw > 14' support and optimized JIT
3964 * implementation for small output-width 'ow = 1' */
3965 const int min_filter_size = 14;
3966 const int max_filter_size = 20;
3967 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
3968 const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
3969 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
3970 auto curr_src_tag = src_d.matches_one_of_tag(
3971 dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx);
3972 auto curr_dst_tag
3973 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
3974 bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
3975 src_d.format_kind() == format_kind::any)
3976 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
3977 diff_dst_d.format_kind() == format_kind::any)
3978 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
3979
    /* Optimization: when `output-width == 1' deploy a special case of the
     * JIT-Kernel by unrolling with respect to height instead of width for
     * the source and filter tensors. The JIT-Kernel also transposes the
     * strides for the input and filter memory access. */
3984 jcp.is_hw_transp = !is_data_layout_nxc && ndims == 4
3985 && jcp.kw >= min_filter_size && jcp.kw < max_filter_size
3986 && jcp.ow == 1 && jcp.kw == jcp.iw
3987 && everyone_is(1, jcp.stride_w, jcp.stride_h)
3988 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
3989 && everyone_is(0, jcp.l_pad, jcp.t_pad, jcp.r_pad, jcp.b_pad);
3990 if (jcp.is_hw_transp) {
3991 jcp.tr_kw = jcp.kh;
3992 jcp.tr_kh = jcp.kw;
3993 jcp.tr_iw = jcp.ih;
3994 jcp.tr_ih = jcp.iw;
3995 }
3996
3997 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
3998 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
3999 jcp.ohp = jcp.oh;
4000 jcp.owp = jcp.ow;
4001 jcp.aligned_threads = 0;
4002
4003 /* check for the 1st convolution */
4004 jcp.is_1stconv = is_1stconv(jcp);
4005
4006 jcp.oc_block = jcp.simd_w;
4007
4008 bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
4009 && src_d.data_type() == data_type::f32;
4010
4011 if (ok_to_pad_channels) jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
4012
4013 if (!IMPLICATION(!is_data_layout_nxc, jcp.oc % jcp.oc_block == 0))
4014 return status::unimplemented;
4015 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
4016 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0;
4017
4018 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
4019 auto wei_tag = with_groups
4020 ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
4021 : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
4022
4023 if (diff_dst_md.format_kind == format_kind::any) {
4024 CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
4025 } else if (curr_dst_tag != dst_tag)
4026 return status::unimplemented;
4027 jcp.dst_tag = dst_tag;
4028
4029 /* conditions on bias memory */
4030 jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
4031 if (jcp.with_bias) {
4032 if (diff_bias_d.format_kind() == format_kind::any)
4033 CHECK(memory_desc_init_by_tag(diff_bias_md, x));
4034 }
4035
4036 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
4037
4038 /* kernel applicability check wrt boundaries
4039 * the conditions are quite general across the kernels we have,
4040 * but ideally the check should belong to a specific kernel... */
4041 const int max_pad_h = ext_kh / 2;
4042 const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
4043 && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
4044 && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd
4045 && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad)
4046 && jcp.l_pad <= max_ur_w && jcp.r_pad <= max_ur_w;
4047 if (!boundaries_ok) return status::unimplemented;
4048
4049 /* yet another common check */
4050 if (!jcp.is_hw_transp && jcp.kw > 14) return status::unimplemented;
4051
4052 /* setting register strategy */
4053 const int unroll_dim = jcp.is_hw_transp ? jcp.oh : jcp.ow;
4054 for (int ur_w = nstl::min(max_ur_w, unroll_dim); ur_w > 0; --ur_w) {
4055 if (unroll_dim % ur_w == 0) {
4056 jcp.ur_w = ur_w;
4057 break;
4058 }
4059 }
4060
4061 if (jcp.is_1stconv) {
4062 auto src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx;
4063 if (src_d.format_kind() == format_kind::any) {
4064 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4065 } else {
4066 // if `ic == 1`, then `nxc` and `ncx` are effectively equivalent
4067 if (jcp.ic == 1 && one_of(curr_src_tag, dat_tag_nxc, dat_tag_ncx))
4068 src_tag = curr_src_tag;
4069 if (curr_src_tag != src_tag) return status::unimplemented;
4070 }
4071 jcp.src_tag = src_tag;
4072
4073 const bool src_ok = IMPLICATION(!is_data_layout_nxc,
4074 (one_of(jcp.ic, 1, 2, 3) && jcp.ngroups == 1));
4075 if (!src_ok) return status::unimplemented;
4076
4077 jcp.ic_block = jcp.ic;
4078
4079 wei_tag = with_groups ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
4080 : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
4081
4082 if (init_tag(jcp.wei_tag, diff_weights_md, diff_weights_d, wei_tag)
4083 != status::success)
4084 return status::unimplemented;
4085
4086 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
4087
4088 } else {
4089 auto src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
4090 if (src_md.format_kind == format_kind::any) {
4091 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4092 } else if (curr_src_tag != src_tag)
4093 return status::unimplemented;
4094 jcp.src_tag = src_tag;
4095
4096 if (init_tag(jcp.wei_tag, diff_weights_md, diff_weights_d, wei_tag)
4097 != status::success)
4098 return status::unimplemented;
4099
4100 jcp.ic_block = jcp.simd_w;
4101 if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4102 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
4103 }
4104
4105 jcp.typesize_in = typesize;
4106 jcp.typesize_out = typesize;
4107
4108 bool use_nxc_harness = false;
4109 if (is_data_layout_nxc) {
4110 dim_t kernel_size
4111 = jcp.ic * jcp.oc * jcp.kd * jcp.kh * jcp.kw * jcp.typesize_out;
4112 dim_t src_size
4113 = jcp.mb * jcp.ic * jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
4114 dim_t diff_dst_size
4115 = jcp.mb * jcp.oc * jcp.id * jcp.ih * jcp.iw * jcp.typesize_in;
4116 dim_t data_size = src_size + diff_dst_size;
4117
        // The advantage of the nxc kernel is cache traversal, which comes at
        // the cost of extra work updating the weights buffers more often. As
        // such, if everything fits in cache, this kernel is at a disadvantage
        // to the inner loop over ow. More tuning/balancing is required to
        // determine when this is needed for multidimensional kernels, because
        // data reuse within the kernel height/depth dimension makes the
        // computation more compute bound and the cache traversal advantage
        // less important. Due to the current blocked weights format, the
        // weights and the data buffers cannot both be traversed optimally, so
        // for performance the weights must fit in cache.
4128 const unsigned int L2_cache_size = platform::get_per_core_cache_size(2);
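        // Illustrative numbers only: a 1d f32 case with ic = oc = 64 and
        // kw = 3 gives kernel_size = 64 * 64 * 3 * 4 = 48 KiB, comfortably
        // below half of a typical 1 MiB per-core L2, so the nxc harness is
        // picked once the per-thread data plus kernel footprint exceeds a
        // third of L2 (subject to the other conditions below).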
4129 use_nxc_harness
4130 = (data_size / nthreads + kernel_size > L2_cache_size / 3)
4131 && (jcp.oc % jcp.simd_w == 0) && (jcp.ic % jcp.simd_w == 0)
4132 && jcp.kw > 1 && ndims == 3
4133 && (kernel_size < L2_cache_size / 2);
4134 }
4135
4136 jcp.harness = use_nxc_harness
4137 ? harness_nxc
4138 : ndims == 5 ? harness_3d_reduction : harness_mb_reduction;
4139 if (jcp.dilate_h == 0 && jcp.ndims == 4 && jcp.oh > min_oh_reduce
4140 && !jcp.is_hw_transp && !is_data_layout_nxc)
4141 jcp.harness = harness_2d_reduction; // 2d harness with oh reduction
4142 bool args_ok = true
4143 && IMPLICATION(!is_data_layout_nxc,
4144 jcp.ic % jcp.ic_block == 0 && jcp.oc % jcp.oc_block == 0)
4145 && jcp.ic <= src_d.padded_dims()[1]
4146 && jcp.oc <= diff_dst_d.padded_dims()[1]
4147 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
4148 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
4149 if (!args_ok) return status::unimplemented;
4150
4151 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4152 if (jcp.harness == harness_nxc) {
        // The harness_nxc is quite different from the other kernels. The
        // init_conf function should probably be refactored so that it calls
        // functions along the lines of tune_nxc and tune_fma, which
        // independently tune the kernels for each implementation, with tuning
        // common to multiple implementations performed by helper functions.
        // This would help maintainability and prevent the different
        // implementations from stepping on each other.
4160 int zmm_regs = 32;
4161
4162 // Block by ic and kw in the compute kernel to decrease loads from the
4163 // src buffer
4164 jcp.ur_ic = 2 - jcp.ic % 2;
4165 jcp.ur_kw = 1;
4166 if (jcp.stride_w == jcp.dilate_w + 1) {
4167 jcp.ur_kw = jcp.kw;
4168 if (jcp.kw > 7) {
4169 // Blocking by kw is more effective than by ic in the compute
4170 // kernel since neighbor kw operations share src data
4171 jcp.ur_ic = 1;
4172 if (jcp.kw > zmm_regs / (jcp.ur_ic + 1))
4173 jcp.ur_kw = jcp.kw % (zmm_regs / (jcp.ur_ic + 1));
4174 }
4175 }
4176
4177 // Unroll by ow to decrease updates to diff_weights. In practice, this
4178 // should be approximately 1/4 - 1/2 of the zmm registers
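        // The microkernel keeps ur_ic accumulators, (ur_ow + ur_kw - 1) * ur_ic
        // broadcast src values and ur_ow diff_dst vectors live at once, i.e.
        // ur_kw * ur_ic + ur_ow * (ur_ic + 1) zmm registers, which the bound
        // below keeps within the 32 available.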
4179 jcp.ur_ow = nstl::min(
4180 (zmm_regs - jcp.ur_kw * jcp.ur_ic) / (jcp.ur_ic + 1), jcp.ow);
4181
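        // Split ow across threads in whole ur_ow chunks: each of the nthr_ow
        // threads gets ow_block = div_up(ow_iter, nthr_ow) * ur_ow output
        // columns.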
4182 int work_amount_base = jcp.mb * jcp.od * jcp.oh;
4183 int ow_iter = div_up(jcp.ow, jcp.ur_ow);
4184 int nthr_ow = nstl::min(
4185 jcp.nthr / math::gcd(work_amount_base, jcp.nthr), ow_iter);
4186 int ow_block = div_up(ow_iter, nthr_ow) * jcp.ur_ow;
4187
4188 jcp.ow_block = ow_block;
4189 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
4190
        // Choose a simple parallelization method. A more advanced one may be
        // needed later.
4193 int work_amount = jcp.mb * jcp.od * jcp.oh * jcp.nb_ow;
4194 nthr_mb = nstl::min(jcp.nthr, work_amount);
4195 nthr_g = 1;
4196 nthr_oc_b = 1;
4197 nthr_ic_b = 1;
4198 nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
4199 } else { // balancing
4200 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, jcp.nthr);
4201 }
4202
4203 jcp.nthr = nthr;
4204 jcp.nthr_mb = nthr_mb;
4205 jcp.nthr_g = nthr_g;
4206 jcp.nthr_oc_b = nthr_oc_b;
4207 jcp.nthr_ic_b = nthr_ic_b;
4208
4209 jcp.kernel_kind = embd_bcast;
4210 if (is_data_layout_nxc && jcp.stride_w == 1 && jcp.dilate_w == 0
4211 && !jcp.is_1stconv) {
4212 jcp.kernel_kind = expl_bcast;
4213 }
4214
4215 jcp.nb_ic_blocking_max = 1;
4216 if (is_data_layout_nxc && (jcp.ow > max_ur_w || jcp.ndims == 5)) {
4217 assert(!jcp.is_hw_transp);
4218 jcp.nb_ic_blocking_max = nstl::min(8, div_up(jcp.nb_ic, jcp.nthr_ic_b));
4219 }
4220
4221 return status::success;
4222}
4223
4224void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
4225 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4226 if (jcp.nthr_mb > 1) {
4227 const auto wei_size = static_cast<size_t>(jcp.ngroups)
4228 * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block)
4229 * jcp.kh * jcp.kw * jcp.kd;
4230 const auto bia_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block);
4231 const size_t wei_bia_reduction_size = wei_size + bia_size;
4232
4233 scratchpad.book(key_conv_wei_bia_reduction,
4234 wei_bia_reduction_size * (jcp.nthr_mb - 1), jcp.typesize_out);
4235 scratchpad.book<simple_barrier::ctx_t>(
4236 key_conv_wei_bia_reduction_bctx, 1);
4237 }
4238
4239 if (jcp.with_bias && jcp.oc_without_padding % jcp.oc_block != 0) {
4240 const size_t nelems_padded_bias
4241 = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block);
4242 scratchpad.book(
4243 key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out);
4244 }
4245}
4246
4247void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
4248 const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4249 int &nthr_oc_b_, int &nthr_ic_b_, int nthreads) {
4250 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4251
4252 if (nthreads < j.ngroups) {
4253 /* simplification... fortunately it doesn't hurt much */
4254 nthr_ = nthr_g_ = nthreads;
4255 return;
4256 }
4257
4258 nthr_g_ = j.ngroups;
4259 const int nthr = nthreads / nthr_g_;
4260
4261 const int ih = j.is_hw_transp ? j.tr_ih : j.ih;
4262 const int oh = j.is_hw_transp ? j.ow : j.oh;
4263
4264 int ih_reduce = j.harness == harness_2d_reduction ? ih : 1;
4265 int oh_reduce = j.harness == harness_2d_reduction ? oh : 1;
4266 int ih_no_reduce = j.harness == harness_2d_reduction ? 1 : ih;
4267 int oh_no_reduce = j.harness == harness_2d_reduction ? 1 : oh;
4268 int nthr_oh_reduce = nstl::max(1, oh_reduce / min_oh_reduce);
4269
4270 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4271 /* calculate per thread memory cost (read/write). high level optimizer
4272 * tries to minimize memory consumption. few notes:
4273 * (n1) unclear why, but that essentially helps first convolution...
4274 * (n2) assuming the reduction over minibatch is always there:
4275 * - instead of 8 it should be 5 here (write ~= 2 read):
4276 * kernel: temporal workspace 1 write
4277 * reduction: 1 read from workspace and 1 write to the diff_wei
4278 * - but experiments showed 8 works better than 5 or 6... */
4279
4280 const dim_t src_coef = 1;
4281 const dim_t dst_coef = 1;
4282 const dim_t wei_coef = 8;
4283 const dim_t iw = j.is_hw_transp ? j.tr_iw : j.iw;
4284 const dim_t ow = j.is_hw_transp ? j.oh : j.ow;
4285
4286 return 0
4287 + src_coef * div_up(j.mb * ih_reduce, nthr_mb)
4288 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_ic, nthr_ic_b)
4289 * j.ic_block * ih_no_reduce * iw * j.id / j.stride_d
4290 / j.stride_h / j.stride_w /* (n1) */
4291 + dst_coef * div_up(j.mb * oh_reduce, nthr_mb)
4292 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b)
4293 * j.oc_block * oh_no_reduce * ow * j.od
4294 + wei_coef /* (n2) */
4295 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b)
4296 * div_up(j.nb_ic, nthr_ic_b) * j.kh * j.kw * j.kd * j.ic_block
4297 * j.oc_block;
4298 };
4299
4300 dim_t best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4301
4302 /* step 1: find the best thread distribution with lowest memory cost */
4303 const int nthr_mb_max = nstl::min(nthr, j.mb * j.od * nthr_oh_reduce);
4304 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4305 const int nthr_par = nthr / nthr_mb;
4306 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4307 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4308 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4309
4310 dim_t mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4311 if (mem_cost <= best_mem_cost) {
4312 best_mem_cost = mem_cost;
4313 nthr_mb_ = nthr_mb;
4314 nthr_oc_b_ = nthr_oc_b;
4315 nthr_ic_b_ = nthr_ic_b;
4316 }
4317 }
4318 }
4319
4320 auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4321 return (dim_t)div_up(j.mb * oh_reduce, nthr_mb)
4322 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b)
4323 * div_up(j.nb_ic, nthr_ic_b);
4324 };
4325
    /* step 2: search for a thread distribution with a lower compute cost.
     * the constraints:
     * - the memory cost cannot exceed 110% of the best found in step 1
     * - unless the compute cost is at most 3/4 of the current best
     * note: both constants were found empirically */
4331 dim_t best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4332 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4333 const int nthr_par = nthr / nthr_mb;
4334 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4335 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4336 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4337 dim_t mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4338 dim_t comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4339
4340 const bool opt1 = comp_cost <= best_comp_cost
4341 && IMPLICATION(
4342 !j.is_hw_transp, mem_cost < 1.1 * best_mem_cost);
4343 const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
4344
4345 if (opt1 || opt2) {
4346 best_comp_cost = comp_cost;
4347 nthr_mb_ = nthr_mb;
4348 nthr_oc_b_ = nthr_oc_b;
4349 nthr_ic_b_ = nthr_ic_b;
4350 }
4351 }
4352 }
4353
4354 if (nthr_mb_ > nthreads / 2 && nthr_mb_ < nthreads)
4355 nthr_mb_ = nstl::min(j.mb * j.od * nthr_oh_reduce, nthreads);
4356 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
4357
4358 assert(nthr_ <= nthreads);
4359}
4360
4361template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
4362template struct _jit_avx512_common_conv_fwd_kernel<Ymm>;
4363template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
4364template struct _jit_avx512_common_conv_bwd_data_kernel_f32<Zmm>;
4365template struct _jit_avx512_common_conv_bwd_data_kernel_f32<Ymm>;
4366template struct _jit_avx512_common_conv_bwd_data_kernel_f32<Xmm>;
4367
4368} // namespace x64
4369} // namespace cpu
4370} // namespace impl
4371} // namespace dnnl
4372
4373// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
4374