/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx2_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)
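// GET_OFF(f) is the byte offset of field f within the jit_conv_call_s
// argument structure; the generated code uses it to address call arguments
// relative to the args pointer held in param1.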

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace Xbyak;

jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx2)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 15;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

        rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
                r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        static_params_t static_params {this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx2>>(
                this, jcp.post_ops, static_params);
    }
}

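// Computes one row of outputs with the kernel-width (kw) loop fully
// unrolled. Register allocation, as used throughout the forward kernel:
//   Ymm(ur_w * ii + jj)        - accumulator for oc block ii, ow point jj
//   Ymm(oc_blocks * ur_w + jj) - input value broadcast for ow point jj
//   ymm15                      - the current weights vector
// ytmp serves as a scratch register on the plain Intel AVX path (no fma).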
void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
        int jj_end = ur_w
                - nstl::max(0,
                        div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w,
                                stride_w));

        auto compute = [=](int cur_ic_blk) {
            for (int ifm2 = 0; ifm2 < cur_ic_blk; ifm2++) {
                for (int jj = jj_start; jj < jj_end; jj++) {
                    size_t inp_off = get_input_offset(
                            ifm2, filter_w_to_input(ki, jj, pad_l));
                    vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                            make_safe_addr(
                                    aux_reg_input, inp_off, reg_long_offt));
                }

                for (int ii = 0; ii < oc_blocks; ii++) {
                    vmovups(ymm15,
                            make_safe_addr(aux_reg_kernel,
                                    get_kernel_offset(ii, ki, ifm2),
                                    reg_long_offt));
                    for (int jj = jj_start; jj < jj_end; jj++)
                        if (mayiuse(avx2))
                            vfmadd231ps(Ymm(ur_w * ii + jj),
                                    Ymm(oc_blocks * ur_w + jj), ymm15);
                        else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                            vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                            vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                                    ytmp);
                        }
                }
            }
        };

        if (ic_tail) {
            if (jcp.ic == ic_tail)
                compute(ic_tail);
            else {
                Label ic_blk_tail, ic_blk_done;
                cmp(reg_channel, ic_block);
                jl(ic_blk_tail, T_NEAR);

                compute(ic_block);
                jmp(ic_blk_done, T_NEAR);

                L(ic_blk_tail);
                compute(ic_tail);

                L(ic_blk_done);
            }
        } else {
            compute(ic_block);
        }
    }
}

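// Same computation as oh_step_unroll_kw, but with a runtime loop over kw
// instead of full unrolling. It is only used when no horizontal padding is
// involved, so every ow point is valid for every kernel column.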
void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    Label kw_loop;

    int kw = jcp.kw;
    int ic_blk = jcp.ic_block;

    xor_(ki_iter, ki_iter);
    L(kw_loop);
    {
        int jj_start = 0;
        int jj_end = ur_w;
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off = get_input_offset(
                        ifm2, filter_w_to_input(0, jj, pad_l));
                vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                        make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
            }
            for (int ii = 0; ii < oc_blocks; ii++) {
                vmovups(ymm15,
                        make_safe_addr(aux_reg_kernel,
                                get_kernel_offset(ii, 0, ifm2), reg_long_offt));
                for (int jj = jj_start; jj < jj_end; jj++)
                    if (mayiuse(avx2))
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(oc_blocks * ur_w + jj), ymm15);
                    else { // Intel AVX support
                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                    }
            }
        }
        safe_add(aux_reg_kernel, get_kernel_offset(0, 1, 0), reg_long_offt);
        safe_add(aux_reg_input, get_input_offset(0, filter_w_to_input(1)),
                reg_long_offt);

        inc(ki_iter);
        cmp(ki_iter, kw);
        jl(kw_loop, T_NEAR);
    }
}

static int get_ymm_idx(
        const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    return (ur_w * oc_block_idx + ur_w_idx);
}

static Ymm get_ymm(const int ur_w, const int oc_block_idx, const int ur_w_idx) {
    return Ymm(get_ymm_idx(ur_w, oc_block_idx, ur_w_idx));
}

template <typename F>
void iterate(const int load_loop_blk, const int ur, const int load_dim_tail,
        const F &f) {
    for (int i = 0; i < load_loop_blk; ++i) {
        const bool mask_flag = (load_dim_tail > 0) && (i == load_loop_blk - 1);
        for (int j = 0; j < ur; ++j)
            f(mask_flag, i, j);
    }
}
template <typename F>
void iterate(const int load_loop_blk, const int ur, const F &f) {
    iterate(load_loop_blk, ur, 0, f);
}

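// Applies eltwise/binary post-ops to the accumulators. Post-ops run only
// once, on the last input-channel chunk (FLAG_IC_LAST); for binary post-ops
// with an oc tail, a separate path passes the tail-masked indices to the
// injector.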
void jit_avx2_conv_fwd_kernel_f32::apply_postops(
        const int oc_blocks, const int ur_w, const int oc_tail) {
    if (jcp.with_eltwise || jcp.with_binary) {
        Label regular_store;
        test(reg_ci_flag, FLAG_IC_LAST);
        je(regular_store, T_NEAR);

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            iterate(oc_blocks, ur_w, oc_tail,
                    [&](const bool mask_flag, const int i, const int j) {
                        const size_t aux_output_offset
                                = get_output_offset(i, j);
                        const auto vmm_idx = get_ymm_idx(ur_w, i, j);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, reg_output);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_offset);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            Label postops_done;
            if (oc_tail) {
                Label postops_no_tail;
                test(reg_oc_flag, FLAG_OC_LAST);
                je(postops_no_tail, T_NEAR);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            iterate(oc_blocks, ur_w, [&](const bool, const int i, const int j) {
                vmm_idxs.emplace(get_ymm_idx(ur_w, i, j));
            });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
        L(regular_store);
    }
}

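// Computes one width block of ur_w output columns: initializes the
// accumulators from the output/bias, runs the kd/kh loops of FMAs over the
// filter, applies post-ops, and stores the results, with tail handling for
// the last oc block where needed.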
void jit_avx2_conv_fwd_kernel_f32::width_blk_step(
        int ur_w, int pad_l, int pad_r, int oc_blocks) {
    int kw = jcp.kw;
    int oc_blk = jcp.oc_block;
    int oc_tail = jcp.oc_tail;

    if (oc_tail) {
        push(reg_oc_blocks);
        mov(reg_oc_flag, ptr[param1 + GET_OFF(oc_flag)]);
    }

    auto load_output_bias_and_add_bias = [=](bool is_tail) {
        Label init_done, init_first;

        if (!jcp.with_sum) {
            test(reg_ci_flag, FLAG_IC_FIRST);
            jne(init_first, T_NEAR);
        }

        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                const auto ymm = get_ymm(ur_w, ii, jj);
                if (is_tail && ii == oc_blocks - 1)
                    load_bytes(ymm, reg_output, get_output_offset(ii, jj),
                            oc_tail * sizeof(float));
                else
                    vmovups(ymm,
                            make_safe_addr(reg_output,
                                    get_output_offset(ii, jj), reg_long_offt));
            }

        if (jcp.with_sum && jcp.with_bias) {
            test(reg_ci_flag, FLAG_IC_FIRST);
            je(init_done, T_NEAR);

            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    if (is_tail && ii == oc_blocks - 1) {
                        load_bytes(ytmp, reg_bias, sizeof(float) * ii * oc_blk,
                                oc_tail * sizeof(float));
                        vaddps(ymm, ymm, ytmp);
                    } else {
                        vaddps(ymm, ymm,
                                yword[reg_bias + sizeof(float) * ii * oc_blk]);
                    }
                }
        }
        jmp(init_done, T_NEAR);

        L(init_first);

        if (jcp.with_bias) {
            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    if (is_tail && ii == oc_blocks - 1)
                        load_bytes(ymm, reg_bias, sizeof(float) * ii * oc_blk,
                                oc_tail * sizeof(float));
                    else
                        vmovups(ymm,
                                yword[reg_bias + sizeof(float) * ii * oc_blk]);
                }
        } else {
            for (int ii = 0; ii < oc_blocks; ii++)
                for (int jj = 0; jj < ur_w; jj++) {
                    const Ymm ymm = get_ymm(ur_w, ii, jj);
                    uni_vpxor(ymm, ymm, ymm);
                }
        }
        L(init_done);
    };

    if (oc_tail) {
        if (jcp.nb_oc > jcp.nb_oc_blocking) {
            Label load_tail, load_done;
            test(reg_oc_flag, FLAG_OC_LAST);
            jne(load_tail, T_NEAR);

            load_output_bias_and_add_bias(false);
            jmp(load_done, T_NEAR);

            L(load_tail);
            load_output_bias_and_add_bias(true);

            L(load_done);
        } else {
            load_output_bias_and_add_bias(true);
        }
    } else {
        load_output_bias_and_add_bias(false);
    }

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
    }

    Label skip_kh_loop, skip_kd_loop, kd_loop;
    if (jcp.ndims == 5) {
        push(reg_output);
        push(oi_iter);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_input);

        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_loop);
        mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_input, aux_reg_inp_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1)
                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    Label kh_loop;
    L(kh_loop);
    {
        if ((jcp.ic % jcp.ic_block == 0) && jcp.kw >= 5 && pad_l == 0
                && pad_r == 0) {
            oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
            add(aux_reg_input,
                    get_input_offset(0, filter_h_to_input(1))
                            - get_input_offset(0, filter_w_to_input(kw)));
        } else {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
            safe_add(
                    aux_reg_kernel, get_kernel_offset(0, kw, 0), reg_long_offt);
            safe_add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)),
                    reg_long_offt);
        }

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        safe_add(aux_reg_inp_d, get_input_offset(0, filter_d_to_input(1)),
                reg_long_offt);
        safe_add(aux_reg_ker_d, get_kernel_offset(0, jcp.kw * jcp.kh, 0),
                reg_long_offt);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
        pop(reg_output);
    }

    apply_postops(oc_blocks, ur_w, oc_tail);

    auto store_output = [=](bool is_tail, int tail) {
        const auto is_padding = jcp.oc_without_padding != jcp.oc;
        if (is_padding) uni_vxorps(ytmp, ytmp, ytmp);
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                Ymm reg_out = get_ymm(ur_w, ii, jj);
                if (is_tail && ii == oc_blocks - 1) {
                    if (is_padding && jcp.with_binary) {
                        vmovups(make_safe_addr(reg_output,
                                        get_output_offset(ii, jj),
                                        reg_long_offt),
                                ytmp);
                    }
                    store_bytes(reg_out, reg_output, get_output_offset(ii, jj),
                            tail * sizeof(float));
                } else
                    vmovups(make_safe_addr(reg_output,
                                    get_output_offset(ii, jj), reg_long_offt),
                            reg_out);
            }
    };

    if (oc_tail) {
        if (jcp.nb_oc > jcp.nb_oc_blocking) {
            Label store_tail, store_done;
            test(reg_oc_flag, FLAG_OC_LAST);
            jne(store_tail, T_NEAR);

            store_output(false, oc_tail);
            jmp(store_done, T_NEAR);

            L(store_tail);
            store_output(true, oc_tail);

            L(store_done);
        } else {
            store_output(true, oc_tail);
        }
    } else {
        Label regular_store;
        Label store_done;
        const int tail = jcp.oc_without_padding % jcp.oc_block;
        if (jcp.with_binary && tail) {
            test(reg_ci_flag, FLAG_IC_LAST);
            je(regular_store, T_NEAR);
            if (!oc_tail) mov(reg_oc_flag, ptr[param1 + GET_OFF(oc_flag)]);
            test(reg_oc_flag, FLAG_OC_LAST);
            je(regular_store, T_NEAR);
            store_output(true, tail);
            jmp(store_done, T_NEAR);
        }

        L(regular_store);
        store_output(false, oc_tail);

        L(store_done);
    }

    if (oc_tail) pop(reg_oc_blocks);
}

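// Splits the output width into up to four kinds of steps: an "lpad" step
// that sees left padding, n_oi "middle" steps with no padding, an "rpad"
// step that sees right padding, and a final "tail" step of ur_w_tail
// columns.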
inline void jit_avx2_conv_fwd_kernel_f32::solve_common(int oc_blocks) {
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int str_w = jcp.stride_w;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, jcp.r_pad);
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, str_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));
    if (r_pad1 > 0) n_oi--;

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w, l_pad)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    Label ow_loop;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop);

        width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop, T_NEAR);
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
        add(reg_input, get_input_offset(0, filter_w_to_input(0, ur_w)));
        add(reg_output, get_output_offset(0, ur_w));
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
}

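// Kernel entry point: loads the call arguments, then dispatches at runtime
// (via reg_oc_blocks) between the full nb_oc_blocking body and the nb_oc
// tail.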
void jit_avx2_conv_fwd_kernel_f32::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);

    if (is_src_layout_nxc())
        mov(reg_channel, ptr[param1 + GET_OFF(reduce_work)]);

    int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;

    Label tail, exit;

    if (jcp.nb_oc > jcp.nb_oc_blocking) {
        cmp(reg_oc_blocks, jcp.nb_oc_blocking);
        jne(nb_oc_tail ? tail : exit, T_NEAR);

        solve_common(jcp.nb_oc_blocking);
        jmp(exit, T_NEAR);

        if (nb_oc_tail) {
            L(tail);
            cmp(reg_oc_blocks, nb_oc_tail);
            jne(exit, T_NEAR);
            solve_common(nb_oc_tail);
        }

        L(exit);
    } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
        solve_common(jcp.nb_oc_blocking);
    } else {
        solve_common(nb_oc_tail);
    }

    this->postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr) {
    if (!mayiuse(avx)) return status::unimplemented;
    jcp.isa = mayiuse(avx2) ? avx2 : avx;

    jcp.nthr = dnnl_get_max_threads();

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag_OIxio = with_groups
            ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
            : pick(ndims - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
    auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o)
                                    : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);

    jcp.src_tag
            = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);

    // Disable this kernel for high-width 1d shapes, where the gemm-based
    // convolution performs better until this kernel is optimized for them.
    if (is_data_layout_nxc && ndims == 3 && jcp.ow > 11 * 1024)
        return status::unimplemented;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    const auto &post_ops = attr.post_ops_;

    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    jcp.post_ops = post_ops;

    const int simd_w = 8;
    const bool flat = jcp.ic < simd_w;
    const bool mimo = !flat;

    /* Grouped channel offset to support 'non-blocked data' format for
     * convolution sizes with '(input_channel / ngroups) < simd' */
    jcp.nonblk_group_off
            = one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic
                                                                       : 1;

    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    if (jcp.with_eltwise || jcp.with_binary)
        if (!mayiuse(avx2)) return status::unimplemented;

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    bool args_ok = true
            && IMPLICATION(flat,
                    jcp.wei_tag == wei_tag_Oxio
                            && ((jcp.src_tag == dat_tag_ncx
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo,
                    jcp.wei_tag == wei_tag_OIxio
                            && ((jcp.src_tag == dat_tag_nCx8c
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1];
    if (!args_ok) return status::unimplemented;

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.ur_w = 3;

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);

    jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */

    // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
    // Thus, we can only assign 14 or 15 YMMs for data storage
    const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
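    // For example, the default nb_oc_blocking = 4 with ur_w = 3 takes
    // 4 * 3 = 12 accumulators plus 3 broadcast registers, i.e. all 15 YMMs
    // available on Intel AVX2.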
    if (!mayiuse(avx2)) {
        if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
            // The current register assignment requires more YMMs than are
            // available; reduce ur_w or nb_oc_blocking, preserving
            // ur_w >= l_pad.
            if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
                jcp.ur_w -= 1;
            else {
                for (int b = 3; b > 1; b--) {
                    if (jcp.nb_oc % b == 0) {
                        jcp.nb_oc_blocking = b;
                        break;
                    }
                }
                if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
                    // No suitable 'nb_oc_blocking' with regard to 'nb_oc';
                    // default to unrolling by 'ur_w' only.
                    jcp.nb_oc_blocking = 1;
                }
            }
        }
    }

    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    args_ok = true && IMPLICATION(!is_data_layout_nxc, jcp.oc % simd_w == 0)
            && jcp.l_pad <= jcp.ur_w
            && IMPLICATION(jcp.kw > 7,
                    (jcp.t_pad == 0 && jcp.l_pad == 0)
                            || (jcp.stride_w == 1 && jcp.stride_h == 1))
            && IMPLICATION(mimo && !is_data_layout_nxc, jcp.ic % simd_w == 0);
    if (!args_ok) return status::unimplemented;

    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc
            ? jcp.oc % simd_w
            : (jcp.with_binary ? jcp.oc_without_padding % simd_w : 0);

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));

    if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
        jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
                nstl::min(jcp.ow, num_avail_regs / 2));
        jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
        /* check again ... */
        r_pad_no_tail = nstl::max(0,
                calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                        jcp.stride_w, ext_kw));
        if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
            return status::unimplemented;
    }
    assert(jcp.nb_oc_blocking > 0);
    assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);

    jcp.ic_block = flat ? jcp.ic : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    jcp.nb_ic_blocking = 12;
    jcp.nb_ic_blocking_max = 16;

    /* Adjust the thread decomposition to improve performance for small
     * problem sizes. The L1_cache_size threshold is empirical; simply cap
     * the thread count at 4 for now.
     * TODO: add a get_thr_eff function to compute the optimal thread count */
    size_t wei_size = (size_t)sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw
            * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);

    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    return status::success;
}

void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book<float>(key_conv_padded_bias, jcp.oc);
}

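// Backward-data compute loop. It mirrors the forward register scheme:
//   Ymm(ur_w * ii + jj)                     - diff_src accumulators
//   Ymm(nb_ic_block * ur_w + jj / stride_w) - broadcast diff_dst values
//   ymm15                                   - the current weights vector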
void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(
        int ur_w, int l_overflow, int r_overflow) {
    int kw = jcp.kw;
    int ow = jcp.ow;

    int oc_block = jcp.oc_block;
    int nb_ic_block = jcp.nb_ic_blocking;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;
    int oc_tail = jcp.oc_tail;
    int ic_tail = jcp.ic_tail;

    Label kd_loop, skip_kd_loop;
    Label oc_loop, skip_oc_loop;

    for (int ii = 0; ii < nb_ic_block; ii++)
        for (int jj = 0; jj < ur_w; jj++) {
            uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                    Ymm(ur_w * ii + jj));
        }

    if (oc_tail) {
        push(reg_long_offt);
        mov(reg_reduce_work, ptr[param1 + GET_OFF(reduce_work)]);
    }

    if (one_of(jcp.ndims, 3, 4)) {
        cmp(reg_channel_work, 0);
        jle(skip_oc_loop, T_NEAR);
        xor_(reg_channel, reg_channel);

        mov(aux_reg_ddst_oc_loop, reg_ddst);
        mov(aux_reg_kernel_oc_loop, reg_kernel);

        L(oc_loop);
        mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
        mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
    }

    if (jcp.ndims == 5) {
        assert(jcp.nb_oc_blocking == 1);
        push(oi_iter);

        mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
        cmp(reg_ki, 0);
        jle(skip_kd_loop, T_NEAR);

        mov(aux_reg_dst_d, reg_ddst);
        mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);

        L(kd_loop);
        mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_ddst, aux_reg_dst_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    Label kh_loop, skip_kh_loop;
    cmp(kj, 0);
    jle(skip_kh_loop, T_NEAR);

    L(kh_loop);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = get_iw_start(ki, l_overflow); // 0;
            int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;

            auto compute = [=](int cur_oc_blk) {
                for (int ofm2 = 0; ofm2 < cur_oc_blk; ofm2++) {
                    for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                        int aux_output_offset = get_ddst_offset(
                                0, filter_w_to_ddst(ki, jj, jcp.l_pad), ofm2);
                        vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
                                ptr[aux_reg_ddst + aux_output_offset]);
                    }

                    for (int ii = 0; ii < nb_ic_block; ii++) {
                        vmovups(ymm15,
                                ptr[aux_reg_kernel
                                        + get_kernel_offset(0, ii, ki, ofm2)]);
                        for (int jj = jj_start; jj < jj_end; jj += stride_w)
                            vfmadd231ps(Ymm(ur_w * ii + jj),
                                    Ymm(nb_ic_block * ur_w + jj / stride_w),
                                    ymm15);
                    }
                }
            };

            if (oc_tail) {
                if (jcp.oc == oc_tail)
                    compute(oc_tail);
                else {
                    Label oc_blk_tail, oc_blk_done;
                    cmp(reg_reduce_work, oc_block);
                    jl(oc_blk_tail, T_NEAR);
                    compute(oc_block);
                    jmp(oc_blk_done, T_NEAR);

                    L(oc_blk_tail);
                    compute(oc_tail);

                    L(oc_blk_done);
                }
            } else {
                compute(oc_block);
            }
        }

        add(aux_reg_kernel, get_kernel_offset(0, 0, stride_h * kw, 0));
        sub(aux_reg_ddst, get_ddst_offset(0, (jcp.dilate_h + 1) * ow, 0));

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }
    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        sub(aux_reg_dst_d,
                get_ddst_offset(0, (jcp.dilate_d + 1) * jcp.oh * ow, 0));
        add(aux_reg_ker_d, get_kernel_offset(0, 0, jcp.kw * jcp.kh, 0));

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
    }

    if (one_of(jcp.ndims, 3, 4)) {
        int ddst_oc_shift = get_ddst_offset(1, 0, 0);
        int kernel_oc_shift = get_kernel_offset(1, 0, 0, 0);

        add(aux_reg_ddst_oc_loop, ddst_oc_shift);
        add(aux_reg_kernel_oc_loop, kernel_oc_shift);

        if (oc_tail) sub(reg_reduce_work, jcp.oc_block);
        inc(reg_channel);
        cmp(reg_channel, reg_channel_work);
        jl(oc_loop, T_NEAR);

        L(skip_oc_loop);
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    }

    if (oc_tail) pop(reg_long_offt);

    auto load_store_dsrc = [=](bool is_tail) {
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
        Label no_update_label;
        cmp(reg_channel, 0);
        je(no_update_label, T_NEAR);

        for (int ii = 0; ii < nb_ic_block; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                if (is_tail && ii == nb_ic_block - 1)
                    load_bytes(Ymm(15), reg_dsrc, get_dsrc_offset(ii, jj),
                            ic_tail * sizeof(float));
                else
                    vmovups(Ymm(15),
                            make_safe_addr(reg_dsrc, get_dsrc_offset(ii, jj),
                                    reg_long_offt));
                vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(15));
            }

        L(no_update_label);

        for (int ii = 0; ii < nb_ic_block; ii++)
            for (int jj = 0; jj < ur_w; jj++) {
                if (is_tail && ii == nb_ic_block - 1)
                    store_bytes(Ymm(ur_w * ii + jj), reg_dsrc,
                            get_dsrc_offset(ii, jj), ic_tail * sizeof(float));
                else
                    vmovups(make_safe_addr(reg_dsrc, get_dsrc_offset(ii, jj),
                                    reg_long_offt),
                            Ymm(ur_w * ii + jj));
            }
    };

    if (ic_tail) {
        Label load_store_tail, load_store_done;
        mov(reg_ci_flag, ptr[param1 + GET_OFF(flags)]);
        test(reg_ci_flag, FLAG_IC_LAST);
        jne(load_store_tail, T_NEAR);

        load_store_dsrc(false);
        jmp(load_store_done, T_NEAR);

        L(load_store_tail);
        load_store_dsrc(true);

        L(load_store_done);
    } else {
        load_store_dsrc(false);
    }
}

void jit_avx2_conv_bwd_data_kernel_f32::generate() {
    preamble();

    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);

    int ddst_shift = get_ddst_offset(0, filter_w_to_ddst(0, jcp.ur_w), 0);
    int dsrc_shift = get_dsrc_offset(0, jcp.ur_w);

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);

    int l_overflow = nstl::max(0, (ext_kw - 1 - jcp.l_pad) / jcp.stride_w);
    int r_overflow = nstl::max(
            0, (ext_kw - 1 - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
    int r_overflow1 = nstl::max(
            0, (ext_kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w);

    int n_oi = jcp.iw / jcp.ur_w;
    if (r_overflow1 > 0) n_oi--;

    if (jcp.ur_w == jcp.iw) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow);
    } else if (n_oi == 0) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow1);
        add(reg_dsrc, dsrc_shift);
        add(reg_ddst, ddst_shift);
        if (jcp.ur_w_tail != 0) compute_loop(jcp.ur_w_tail, 0, r_overflow);
    } else {
        xor_(oi_iter, oi_iter);
        if (l_overflow > 0) {
            compute_loop(jcp.ur_w, l_overflow, 0);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);
            inc(oi_iter);
        }

        if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
            Label ow_loop;
            L(ow_loop);
            {
                compute_loop(jcp.ur_w, 0, 0);
                add(reg_dsrc, dsrc_shift);
                add(reg_ddst, ddst_shift);
                inc(oi_iter);
                cmp(oi_iter, n_oi);
                jl(ow_loop, T_NEAR);
            }
        }

        if (r_overflow1 > 0) {
            compute_loop(jcp.ur_w, 0, r_overflow1);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);
        }

        if (jcp.ur_w_tail != 0) compute_loop(jcp.ur_w_tail, 0, r_overflow);
    }

    this->postamble();
}

status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    jcp.nthr = dnnl_get_max_threads();

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    int ndims = diff_src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
            || (jcp.dilate_d != 0 && jcp.stride_d != 1)
            || (jcp.dilate_h != 0 && jcp.stride_h != 1))
        return status::unimplemented;

    const int simd_w = 8;

    /* derivatives */
    jcp.idp = jcp.id + 2 * jcp.f_pad;
    jcp.ihp = jcp.ih + 2 * jcp.t_pad;
    jcp.iwp = jcp.iw + 2 * jcp.l_pad;
    jcp.ohp = jcp.oh; /* do we really need */
    jcp.owp = jcp.ow; /* padded output ??? */

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag = with_groups
            ? pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
            : pick(ndims - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);

    jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);

    jcp.typesize_in = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_out = types::data_type_size(diff_dst_d.data_type());

    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    /* gemm-based convolution performs better in these cases */
    if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
        return status::unimplemented;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    jcp.ic_block = (!is_data_layout_nxc && jcp.ic % simd_w) ? 1 : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc ? jcp.oc % simd_w : 0;

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.nb_ic_blocking = 1;
    jcp.nb_oc_blocking = 1;
    jcp.ur_w = 1;

    if (one_of(ndims, 3, 4) && jcp.ow < 40)
        jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;

    auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;

    bool args_ok = true && jcp.stride_w == jcp.stride_h && jcp.stride_d == 1
            && IMPLICATION(!is_data_layout_nxc,
                    jcp.ic % simd_w == 0 && jcp.oc % simd_w == 0)
            && jcp.ic <= diff_src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.dst_tag == required_dat_tag
            && jcp.src_tag == required_dat_tag && jcp.wei_tag == wei_tag;
    if (!args_ok) return status::unimplemented;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);

    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    int l_overflow = nstl::max(0, (ext_kw - 1 - jcp.l_pad) / jcp.stride_w);

    const int max_regs = 15; /* Maximum number of registers available for
                                result accumulation and delta dst data.
                                One additional register is reserved for weights
                                data. */

    /* Find the best blocking with maximum number of fma instructions
       per ur_w * nb_ic_blocking compute loops. Number of required registers
       is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
       ur_w must be divisible by stride_w */
    if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
                                         distribution exceeds max_regs */
        return status::unimplemented;

    int best_nfmas = 0;
    for (int b = 1; b <= 4; b++) {
        if (jcp.nb_ic % b != 0) continue;

        for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs
                && u < jcp.iw + jcp.stride_w;
                u += jcp.stride_w) {
            int ur_w = nstl::min(u, jcp.iw);
            /* maximum 1 step with l_overflow so far */
            if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) continue;
            int nfmas = div_up(ur_w, jcp.stride_w) * b;
            if (nfmas > best_nfmas
                    || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
                jcp.ur_w = ur_w;
                jcp.nb_ic_blocking = b;
                best_nfmas = nfmas;
            }
        }
    }
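    /* For example, with stride_w = 1 and nb_ic divisible by 4 the search
     * can settle on b = 4 and ur_w = 3: 3 * 4 + 3 = 15 registers (exactly
     * max_regs) and 12 fmas per compute loop. */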
    if (best_nfmas == 0) /* can't find appropriate blocking */
        return status::unimplemented;

    jcp.ur_w_tail = jcp.iw % jcp.ur_w;

    int r_overflow_no_tail = nstl::max(
            0, (ext_kw - 1 - jcp.r_pad - jcp.ur_w_tail) / jcp.stride_w);

    bool tails_not_ok = false
            /* maximum 1 ur_w block with r_overflow so far */
            || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
            /* ur_w must be a multiple of stride */
            || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
            /* r_pad must not extend beyond ur_w_tail */
            || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
    if (tails_not_ok) return status::unimplemented;

    /* Adjust the thread decomposition to improve performance for small
     * problem sizes. The L1_cache_size threshold is empirical; simply cap
     * the thread count at 4 for now.
     * TODO: add a get_thr_eff function to compute the optimal thread count */
    size_t wei_size = (size_t)sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw
            * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    return status::success;
}

void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    UNUSED(scratchpad);
    UNUSED(jcp);
}

void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    compute_oh_loop_common();
    this->postamble();
}

status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &diff_weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    auto wei_tag_OIxio = with_groups
            ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
            : pick(ndims - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
    auto wei_tag_Oxio = with_groups ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o)
                                    : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);

    jcp.src_tag
            = src_d.matches_one_of_tag(dat_tag_ncx, dat_tag_nxc, dat_tag_nCx8c);
    jcp.wei_tag
            = diff_weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);

    bool is_data_layout_nxc
            = everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);

    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;

    const bool flat = jcp.ic == 3;
    const bool mimo = !flat;

    const int simd_w = 8;

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    const int max_h_pad = ext_kh;
    const int max_w_pad = ext_kw;
    const bool boundaries_ok = true && jcp.t_pad < max_h_pad
            && jcp.b_pad < max_h_pad && jcp.l_pad < max_w_pad
            && jcp.r_pad < max_w_pad && jcp.f_pad == 0 && jcp.back_pad == 0;
    if (!boundaries_ok) return status::unimplemented;

    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo) jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % simd_w : 0;
    jcp.oc_tail = is_data_layout_nxc ? jcp.oc % simd_w : 0;

    bool args_ok = true
            && IMPLICATION(flat,
                    jcp.wei_tag == wei_tag_Oxio
                            && ((jcp.src_tag == dat_tag_ncx
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo,
                    jcp.wei_tag == wei_tag_OIxio
                            && ((jcp.src_tag == dat_tag_nCx8c
                                        && jcp.dst_tag == dat_tag_nCx8c)
                                    || (jcp.src_tag == dat_tag_nxc
                                            && jcp.dst_tag == dat_tag_nxc)))
            && IMPLICATION(mimo && !is_data_layout_nxc, jcp.ic % simd_w == 0)
            && IMPLICATION(!is_data_layout_nxc, jcp.oc % simd_w == 0)
            && jcp.kw < 14 && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
            && jcp.kh <= jcp.ih /* [bwd_w:r2] */
            && jcp.kd <= jcp.f_pad + jcp.id && jcp.kd <= jcp.id
            && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
            && jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1];
    if (!args_ok) return status::unimplemented;

    jcp.ic_block = flat ? jcp.ic : simd_w;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    return status::success;
}

void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && (jcp.oc_without_padding % jcp.oc_block != 0)) {
        const size_t nelems_padded_bias
                = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block);
        scratchpad.book<float>(key_conv_padded_bias, nelems_padded_bias);
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() {
    Label kd_comeback_loop;
    mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
    L(kd_comeback_loop);
    {
        sub(aux_reg_input, get_input_offset(0, jcp.iw * jcp.ih));
        sub(aux_reg_kernel, get_kernel_offset(jcp.kw * jcp.kh, 0));
        dec(kj);
        cmp(kj, 0);
        jg(kd_comeback_loop, T_NEAR);
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() {
    mov(kj, reg_kh);
    Label kh_comeback_loop;
    L(kh_comeback_loop);
    {
        sub(reg_input, get_input_offset(0, jcp.iw));
        sub(reg_kernel, get_kernel_offset(jcp.kw, 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_comeback_loop, T_NEAR);
    }
}

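// Accumulates diff_weights for ic_block_step input channels over ur_w
// output columns. Register layout within this step:
//   Ymm(i_kw * ic_block_step + i_ic) - weight-gradient accumulators
//   Ymm(kw * ic_block_step + 0)      - the current diff_dst vector
//   Ymm(kw * ic_block_step + 1)      - the broadcast source value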
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
        int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
        int kernel_offset, int output_offset) {

    if (ic_block_step <= 0) return;

    const int kw = jcp.kw;
    const int oc_tail = jcp.oc_tail;

    if (oc_tail) {
        push(reg_kh);
        mov(reg_ci_flag, ptr[param1 + GET_OFF(flags)]);
    }

    auto load_compute_store = [=](bool is_tail) {
        for (int i_kw = 0; i_kw < kw; i_kw++)
            for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                size_t off = get_kernel_offset(i_kw, i_ic) + kernel_offset;
                if (is_tail)
                    load_bytes(Ymm(i_kw * ic_block_step + i_ic), reg_kernel,
                            off, oc_tail * sizeof(float));
                else
                    vmovups(Ymm(i_kw * ic_block_step + i_ic),
                            yword[reg_kernel + off]);
            }

        for (int i_ur = 0; i_ur < ur_w; i_ur++) {
            if (is_tail)
                load_bytes(Ymm(kw * ic_block_step + 0), reg_output,
                        get_output_offset(0, i_ur) + output_offset,
                        oc_tail * sizeof(float));
            else
                vmovups(Ymm(kw * ic_block_step + 0),
                        yword[reg_output + get_output_offset(0, i_ur)
                                + output_offset]);

            for (int i_kw = 0; i_kw < kw; i_kw++) {
                int i_iw = i_ur * jcp.stride_w + i_kw;
                if (i_iw - pad_l < 0
                        || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
                    continue;
                for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                    size_t i_off = get_input_offset(i_ic, i_iw - pad_l);
                    vbroadcastss(Ymm(kw * ic_block_step + 1),
                            make_safe_addr(reg_input, i_off, reg_long_offt));
                    vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
                            Ymm(kw * ic_block_step + 0),
                            Ymm(kw * ic_block_step + 1));
                }
            }
        }

        for (int i_kw = 0; i_kw < kw; i_kw++)
            for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                size_t off = get_kernel_offset(i_kw, i_ic) + kernel_offset;
                if (is_tail)
                    store_bytes(Ymm(i_kw * ic_block_step + i_ic), reg_kernel,
                            off, oc_tail * sizeof(float));
                else
                    vmovups(yword[reg_kernel + off],
                            Ymm(i_kw * ic_block_step + i_ic));
            }
    };

    if (oc_tail) {
        Label load_tail, load_done;
        test(reg_ci_flag, FLAG_OC_LAST);
        jne(load_tail, T_NEAR);

        load_compute_store(false);
        jmp(load_done, T_NEAR);

        L(load_tail);
        load_compute_store(true);

        L(load_done);
    } else {
        load_compute_store(false);
    }

    if (oc_tail) pop(reg_kh);
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() {
    int ic_block_step;
    if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
        ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
    } else if (one_of(jcp.src_tag, nwc, nhwc, ndhwc)) {
        ic_block_step = jcp.kw > 7 ? 1 : jcp.kw > 3 ? 2 : jcp.kw > 1 ? 4 : 8;
        if (jcp.ic_block % ic_block_step != 0) {
            ic_block_step = jcp.ic_block < ic_block_step ? jcp.ic_block : 1;
        }
        if (jcp.ic < ic_block_step) ic_block_step = jcp.ic;
    } else {
        ic_block_step = jcp.kw > 7 ? 1 : jcp.kw > 3 ? 2 : jcp.kw > 1 ? 4 : 8;
    }

    const int max_ur_w = jcp.ow > 56 ? 14 : 28;

    if (jcp.ow <= max_ur_w || one_of(jcp.src_tag, nwc, nhwc, ndhwc))
        compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
    else
        compute_oh_step_common(ic_block_step, max_ur_w);

    if (jcp.ndims == 5) {
        od_step_comeback_pointers();
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    } else {
        oh_step_comeback_pointers();
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
        int ic_block_step, int max_ur_w) {
    UNUSED(max_ur_w);

    const int r_pad = jcp.r_pad;
    const int ic_tail = jcp.ic_tail;
    const int ic_block = jcp.ic_block;
    const int ic_block_step_tail = jcp.ic % ic_block_step;
    const size_t inp_icblk_stride = get_input_offset(ic_block_step, 0);

    if (ic_tail) {
        push(reg_ih_count);
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    }

    Label kd_loop;
    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop, kh_loop_ic_tail, kh_loop_done;
    if (ic_tail) {
        cmp(reg_channel, ic_block);
        jl(kh_loop_ic_tail, T_NEAR);
    }

    L(kh_loop);
    {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop);
        {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, 0, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));
            add(b_ic, ic_block_step);
            cmp(b_ic, ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_block, 0));
        add(reg_kernel, get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }
    jmp(kh_loop_done, T_NEAR);

    L(kh_loop_ic_tail);
    {
        Label ic_block_loop, ic_block_loop_done;

        cmp(reg_channel, ic_block_step);
        jl(ic_block_loop_done, T_NEAR);

        mov(b_ic, ic_tail);
        L(ic_block_loop);
        {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, 0, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));
            sub(b_ic, ic_block_step);
            cmp(b_ic, ic_block_step);
            jge(ic_block_loop, T_NEAR);
        }

        L(ic_block_loop_done);

        if (ic_block_step_tail) {
            compute_ic_block_step(
                    jcp.ow, jcp.l_pad, r_pad, ic_block_step_tail, 0, 0, 0);
            add(reg_input, get_input_offset(ic_block_step_tail, 0));
            add(reg_kernel, get_kernel_offset(0, ic_block_step_tail));
        }

        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_tail, 0));
        add(reg_kernel,
                get_kernel_offset(0, ic_block - ic_tail)
                        + get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop_ic_tail, T_NEAR);
    }

    L(kh_loop_done);

    if (jcp.ndims == 5) {
        add(aux_reg_input, get_input_offset(0, jcp.ih * jcp.iw));
        add(aux_reg_kernel, get_kernel_offset(jcp.kh * jcp.kw, 0));
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
    if (ic_tail) pop(reg_ih_count);
}

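// Runs the ow loop in ur_w-sized trips. When the tail would otherwise be
// fully covered by right padding, one trip is folded into the tail (or
// ur_w is halved) so that the r_pad handling always lands in the tail
// step.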
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
        int ic_block_step, int max_ur_w) {
    // TODO: support channel tails for the nxc format

    const int ic_block = jcp.ic_block;
    const int stride_w = jcp.stride_w;
    Label kd_loop;

    const int r_pad = jcp.r_pad;

    int ur_w = nstl::min(jcp.ow, max_ur_w);
    int ur_w_trips = jcp.ow / ur_w;
    int ur_w_tail = jcp.ow % ur_w;
    if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
        if (ur_w_trips > 1) {
            ur_w_tail += ur_w;
            ur_w_trips--;
        } else {
            ur_w_tail += (ur_w - ur_w / 2);
            ur_w = ur_w / 2;
        }
    }

    int input_comeback
            = get_input_offset(0, ur_w_trips * ur_w * stride_w - jcp.l_pad);
    int output_comeback = get_output_offset(0, ur_w_trips * ur_w);

    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop);
    {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop);
        {
            if (jcp.l_pad != 0) {
                ur_w_trips--;
                compute_ic_block_step(
                        ur_w, jcp.l_pad, 0, ic_block_step, 0, 0, 0);
                add(reg_input,
                        get_input_offset(0, ur_w * stride_w - jcp.l_pad));
                add(reg_output, get_output_offset(0, ur_w));
            }

            if (ur_w_trips > 0) {
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                Label ow_block_loop;
                L(ow_block_loop);
                {
                    compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
                    add(reg_output, get_output_offset(0, ur_w));
                    add(reg_input, get_input_offset(0, ur_w * stride_w));

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_trips);
                    jl(ow_block_loop, T_NEAR);
                }
            }

            if (ur_w_tail > 0)
                compute_ic_block_step(
                        ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0);

            sub(reg_input, input_comeback);
            sub(reg_output, output_comeback);

            size_t inp_icblk_stride = get_input_offset(ic_block_step, 0);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, get_kernel_offset(0, ic_block_step));

            add(b_ic, ic_block_step);
            cmp(b_ic, jcp.ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        add(reg_input,
                get_input_offset(0, jcp.iw) - get_input_offset(ic_block, 0));
        add(reg_kernel, get_kernel_offset((jcp.kw - 1), 0));
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_input, get_input_offset(0, jcp.ih * jcp.iw));
        add(aux_reg_kernel, get_kernel_offset(jcp.kh * jcp.kw, 0));
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}

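// Walks the output rows in three phases: rows overlapping the top padding
// (where the effective kernel height in reg_kh grows by stride_h per row),
// the unpadded body, and rows running into the bottom padding (where
// reg_kh shrinks by stride_h per row).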
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() {
    const int t_pad = jcp.t_pad;
    const int stride_h = jcp.stride_h;
    int b_pad = jcp.b_pad;

    Label oh_tpad_loop, oh_loop, oh_loop_end;

    mov(reg_kh, jcp.kh);
    xor_(reg_ih_count, reg_ih_count);
    xor_(reg_oj, reg_oj);
    if (t_pad > 0) {
        assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
        mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
        add(reg_kernel, get_kernel_offset(t_pad * jcp.kw, 0));

        L(oh_tpad_loop);
        {
            compute_oh_step_disp();
            add(reg_output, get_output_offset(0, jcp.ow));
            sub(reg_kernel, get_kernel_offset(stride_h * jcp.kw, 0));

            inc(reg_oj);
            add(reg_ih_count, stride_h);
            add(reg_kh, stride_h);

            /* the overlap between input and kernel may not reach kernel size.
             * so far we do not support that (until we put constant here) */
            const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
            cmp(reg_kh, final_inp_ker_overlap);
            jl(oh_tpad_loop, T_NEAR);
        }

        if (t_pad % stride_h != 0) {
            int inp_corr = stride_h - t_pad % stride_h;
            add(reg_kernel, get_kernel_offset(inp_corr * jcp.kw, 0));
            add(reg_input, get_input_offset(0, inp_corr * jcp.iw));
        }
    }
    cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
    jge(oh_loop_end, T_NEAR);
    cmp(reg_oj, jcp.oh);
    jge(oh_loop, T_NEAR);

    mov(reg_kh, jcp.kh);
    L(oh_loop);
    {
        compute_oh_step_disp();
        add(reg_input, get_input_offset(0, stride_h * jcp.iw));
        add(reg_output, get_output_offset(0, jcp.ow));

        inc(reg_oj);
        add(reg_ih_count, stride_h);

        cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
        jge(oh_loop_end, T_NEAR);

        cmp(reg_oj, jcp.oh);
        jl(oh_loop, T_NEAR);
    }
    L(oh_loop_end);
    if (b_pad > 0) {
        Label oh_bpad_loop, oh_bpad_loop_end;
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_loop_end, T_NEAR);

        mov(reg_kh, jcp.ih + t_pad);
        sub(reg_kh, reg_ih_count);
        L(oh_bpad_loop);
        {
            compute_oh_step_disp();
            add(reg_input, get_input_offset(0, stride_h * jcp.iw));
            add(reg_output, get_output_offset(0, jcp.ow));

            sub(reg_kh, stride_h);
            cmp(reg_kh, 0);
            jle(oh_bpad_loop_end, T_NEAR);

            inc(reg_oj);
            cmp(reg_oj, jcp.oh);
            jl(oh_bpad_loop, T_NEAR);
        }
        L(oh_bpad_loop_end);
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s