1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/memory.hpp"
19#include "common/nstl.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22
23#include "cpu/x64/jit_uni_dw_conv_kernel_f32.hpp"
24
25#define GET_OFF(field) offsetof(jit_conv_call_s, field)
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32using namespace dnnl::impl::prop_kind;
33using namespace dnnl::impl::memory_tracking::names;
34using namespace dnnl::impl::utils;
35
36using namespace Xbyak;
37
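// The constructor only sets up the post-ops injector (when eltwise or binary
// post-ops are present); all code generation happens in generate(). Vmm
// index 31 is handed to the binary injector as its helper register, with
// r14/r15/r12 as gpr scratch, as encoded in rhs_arg_static_params below.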
38template <cpu_isa_t isa>
39jit_uni_dw_conv_fwd_kernel_f32<isa>::jit_uni_dw_conv_fwd_kernel_f32(
40 const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md)
41 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp) {
42 if (jcp.with_eltwise || jcp.with_binary) {
43 using namespace binary_injector;
44 static constexpr bool preserve_gpr = true;
45 static constexpr bool preserve_vmm = false;
46 static constexpr size_t helper_vmm_idx = 31;
47 static constexpr bool use_exact_tail_scalar_bcast = true;
48 const size_t tail_size = jcp.oc_without_padding
49 % (cpu_isa_traits<isa>::vlen / sizeof(float));
50 rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r14, r15,
51 r12, preserve_gpr, preserve_vmm,
52 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
53 memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask,
54 use_exact_tail_scalar_bcast};
55 static_params_t static_params {this->param1, rhs_arg_static_params};
56
57 postops_injector_
58 = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
59 this, jcp.post_ops, static_params);
60 }
61}
62
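// Returns true when the given (channel block, repeat) pair needs a partial
// (masked) load/store because it maps onto the last, not fully populated
// channel block. Illustrative example (assuming avx512_core, i.e. vlen of
// 16 f32 lanes and a single repeat): with c_tail == 3 and ur_ch_blocks == 2,
// only ch == 1 takes the tail path, since (i + 1) * vlen == 16 > 3 and
// ch + 1 == ur_ch_blocks.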
63bool check_if_tail_load(const bool is_ch_tail, const int c_tail, const int ch,
64 const int ur_ch_blocks, const int vlen, const int i) {
65 return is_ch_tail && (ch + 1 == ur_ch_blocks) && ((i + 1) * vlen > c_tail);
66}
67
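// Initializes the accumulators of one (ur_ch_blocks x ur_w) output tile: with
// the bias values when bias is present, with zeros otherwise; for the sum
// post-op the existing dst values are then added on top. The accumulator
// index layout used throughout this kernel is
//   acc_idx = i * ur_ch_blocks * ur_w + ch * ur_w + ow,
// where `i` is the repeat index (sse41 processes a channel block as two
// 4-lane halves; avx2/avx512_core use a single repeat).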
68template <cpu_isa_t isa>
69void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(
70 int ur_ch_blocks, int ur_w, bool is_ch_tail) {
71
72 const auto dst_layout_nxc = is_dst_layout_nxc();
73 const auto ch_blk = jcp.ch_block;
74 const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
75 const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
76 const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
77 const int c_tail = jcp.oc % jcp.ch_block;
78
79 const int repeats = max_repeats();
80 for (int i = 0; i < repeats; i++) {
81 for (int ch = 0; ch < ur_ch_blocks; ch++) {
82 const bool is_tail_load = check_if_tail_load(
83 is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i);
84 if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen)
85 continue;
86 for (int ow = 0; ow < ur_w; ow++) {
87 Vmm vmm_acc
88 = get_acc_reg(i * ur_ch_blocks * ur_w + ch * ur_w + ow);
89
90 const int b_off = ch * ch_blk + i * vlen;
91 if (this->jcp.with_bias) {
92 if (is_tail_load) {
93 load_tail(vmm_acc, reg_bias, b_off * sizeof(float),
94 (c_tail - i * vlen) * sizeof(float));
95 } else {
96 uni_vmovups(vmm_acc,
97 vmmword[reg_bias + b_off * sizeof(float)]);
98 }
99 } else {
100 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
101 }
102
103 const int o_off = ch * ocb_stride + ow * ow_stride + i * vlen;
104 if (this->jcp.with_sum) {
105 if (is_tail_load) {
106 if (this->jcp.with_bias) {
107 // using ker_vmm as vmm_tmp as it is safe to do so.
108 auto vmm_tmp = get_ker_reg(0);
109 add_tail_from_mem(vmm_acc, vmm_tmp, reg_output,
110 o_off * sizeof(float),
111 (c_tail - i * vlen) * sizeof(float));
112 } else {
113 // nothing to add, just load dst.
114 load_tail(vmm_acc, reg_output,
115 o_off * sizeof(float),
116 c_tail * sizeof(float));
117 }
118 } else {
119 // blocked layout has dst padded, so no tail handling.
120 uni_vaddps(vmm_acc, vmm_acc,
121 vmmword[reg_output + o_off * sizeof(float)]);
122 }
123 }
124 }
125 }
126 }
127}
128
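// Unrolled convolution body: kh is a runtime loop (reg_kh may be clipped by
// top/bottom padding), while kw and ow are fully unrolled, issuing one FMA
// per (kw, ow) pair and channel block. With jcp.is_resrc_depthwise the
// overlapping input pixels are loaded once into src registers and reused
// across the kw taps instead of being reloaded for every tap.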
129template <cpu_isa_t isa>
130void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
131 int ur_ch_blocks, int ur_w, int pad_l, int pad_r, bool is_ch_tail) {
132 int ch_blk = jcp.ch_block;
133 int dilate_h = jcp.dilate_h + 1;
134 int dilate_w = jcp.dilate_w + 1;
135 int stride_w = jcp.stride_w;
136
137 const auto src_layout_nxc = is_src_layout_nxc();
138 const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk;
139 const auto ih_stride = jcp.iw * iw_stride;
140 const auto icb_stride = src_layout_nxc
141 ? ch_blk
142 : (jcp.is_fused_conv ? 1 : jcp.ih) * jcp.iw * ch_blk;
143 const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
144
145 auto get_input_spatial_index = [=](int oi, int ki) {
146 return (ki * dilate_w + oi * stride_w - pad_l);
147 };
148
149 auto get_input_offset = [=](int ii, int ci, int rep) {
150 return (ci * icb_stride + ii * iw_stride + rep * vlen)
151 * jcp.typesize_in;
152 };
153
154 int ii_start = 0;
155 int ii_end = -1;
156 if (jcp.is_resrc_depthwise) {
157 // find bounds of input spatial indices
158 bool first = true;
159 for (int ki = 0; ki < jcp.kw; ki++) {
160 int oi_start = get_ow_start(ki, pad_l);
161 int oi_end = get_ow_end(ur_w, ki, pad_r);
162 for (int oi = oi_start; oi < oi_end; oi++) {
163 int ii = get_input_spatial_index(oi, ki);
164 if (first || ii < ii_start) ii_start = ii;
165 if (first || ii > ii_end) ii_end = ii;
166 first = false;
167 }
168 }
169 }
170
171 Label iter_exit_label;
172
173 cmp(reg_kh, 0);
174 je(iter_exit_label, T_NEAR);
175
176 mov(iter_kh, reg_kh);
177 Label kh_label;
178 L(kh_label);
179 {
180 if (jcp.is_fused_conv) {
181 mov(aux_reg_input, ptr[aux_reg_input_buffer_ptr]);
182 add(aux_reg_input, reg_iw_offset);
183 }
184 const int c_tail = jcp.oc % jcp.ch_block;
185 const int repeats = max_repeats();
186 for (int i = 0; i < repeats; i++) {
187 for (int ch = 0; ch < ur_ch_blocks; ch++) {
188 const bool is_tail_load = check_if_tail_load(
189 is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i);
190 if ((ch + 1 == ur_ch_blocks) && is_ch_tail
191 && c_tail <= i * vlen)
192 continue;
193 if (jcp.is_resrc_depthwise) {
194 // now we can load input once and reuse up to jcp.kw times
195 for (int ii = ii_start; ii <= ii_end; ii++) {
196 Vmm vmm_src = get_src_reg(ii);
197 const int inp_off = get_input_offset(ii, ch, i);
198 if (is_tail_load) {
199 load_tail(vmm_src, aux_reg_input, inp_off,
200 (c_tail - i * vlen) * jcp.typesize_in);
201 } else {
202 uni_vmovups(vmm_src, ptr[aux_reg_input + inp_off]);
203 }
204 }
205 }
206 for (int kw = 0; kw < jcp.kw; kw++) {
207 const int ker_off = ch * jcp.kh * jcp.kw * ch_blk
208 + kw * ch_blk + i * vlen;
209
210 Vmm vmm_ker = get_ker_reg(0);
211 uni_vmovups(vmm_ker,
212 ptr[aux_reg_kernel + ker_off * sizeof(float)]);
213
214 int ow_start = get_ow_start(kw, pad_l);
215 int ow_end = get_ow_end(ur_w, kw, pad_r);
216 for (int ow = ow_start; ow < ow_end; ow++) {
217
218 const int ii = get_input_spatial_index(ow, kw);
219 Vmm vmm_src = jcp.is_resrc_depthwise ? get_src_reg(ii)
220 : get_src_reg(0);
221 if (!jcp.is_resrc_depthwise) {
222 const int inp_off = get_input_offset(ii, ch, i);
223 if (is_tail_load) {
224 load_tail(vmm_src, aux_reg_input, inp_off,
225 (c_tail - i * vlen) * jcp.typesize_in);
226 } else {
227 uni_vmovups(
228 vmm_src, ptr[aux_reg_input + inp_off]);
229 }
230 }
231 Vmm vmm_acc = get_acc_reg(
232 i * ur_ch_blocks * ur_w + ch * ur_w + ow);
233 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
234 }
235 }
236 }
237 }
238
239 add(aux_reg_kernel, jcp.kw * ch_blk * sizeof(float));
240 if (jcp.is_fused_conv) {
241 // Move to next row pointer in the buffer
242 add(aux_reg_input_buffer_ptr, sizeof(void *));
243 } else {
244 add(aux_reg_input, ih_stride * dilate_h * sizeof(float));
245 }
246
247 dec(iter_kh);
248 cmp(iter_kh, 0);
249 jg(kh_label, T_NEAR);
250 }
251
252 L(iter_exit_label);
253}
254
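// Helper that walks the (repeat, channel block, ow) accumulator space and
// reports, per channel block, whether it is the (potentially masked) last
// one. Used below to build the binary post-op injector argument lists.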
255template <typename F>
256void iterate(const int repeats, const int ur_ch_blocks, const int ur_w,
257 const bool mask_tail, const F &f) {
258 for (int r = 0; r < repeats; r++)
259 for (int ch = 0; ch < ur_ch_blocks; ch++) {
260 const bool mask_flag = mask_tail && ch + 1 == ur_ch_blocks;
261 for (int ow = 0; ow < ur_w; ow++)
262 f(r, ch, ow, mask_flag);
263 }
264}
265
266template <typename F>
267void iterate(
268 const int repeats, const int ur_ch_blocks, const int ur_w, const F &f) {
269 iterate(repeats, ur_ch_blocks, ur_w, false, f);
270}
271
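// Applies the eltwise/binary post-ops to the live accumulators. For binary
// post-ops the injector is given, per vmm, the dst base register and element
// offset (used to address the rhs operand) plus the set of vmms that need a
// tail mask. Two parameter sets are prepared: the tail variant is chosen
// statically for the nxc channel tail and at runtime (via `load_work`) for
// the blocked-layout tail.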
272template <cpu_isa_t isa>
273void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postops(
274 const int ur_ch_blocks, const int ur_w, const bool is_ch_tail) {
275 if (this->jcp.with_eltwise || this->jcp.with_binary) {
276 const int repeats = max_repeats();
277 injector_utils::vmm_index_set_t vmm_idxs;
278 if (jcp.with_binary) {
279 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
280 rhs_arg_params_tail;
281 const auto dst_layout_nxc = is_dst_layout_nxc();
282 const auto ch_blk = jcp.ch_block;
283 const auto ocb_stride
284 = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
285 const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
286 const auto mask_tail_blocked_layout
287 = jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc;
288 const int c_tail = jcp.oc_without_padding % jcp.ch_block;
289 iterate(repeats, ur_ch_blocks, ur_w, mask_tail_blocked_layout,
290 [&](const int r, const int ch, const int ow,
291 const bool mask_flag_blocked_layout) {
292 const int vlen
293 = cpu_isa_traits<isa>::vlen / sizeof(float);
294 const bool is_tail_load = check_if_tail_load(
295 is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, r);
296 if ((ch + 1 == ur_ch_blocks) && is_ch_tail
297 && c_tail <= r * vlen)
298 return;
299 const size_t o_off = jcp.typesize_out
300 * (ch * ocb_stride + ow * ow_stride + r * vlen);
301 const auto vmm_idx = get_acc_reg_idx(
302 r * ur_ch_blocks * ur_w + ch * ur_w + ow);
303 vmm_idxs.emplace(vmm_idx);
304
305 rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
306 vmm_idx, reg_output);
307 rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
308 vmm_idx, o_off);
309 if (mask_flag_blocked_layout || is_tail_load)
310 rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
311 });
312 rhs_arg_params = rhs_arg_params_tail;
313 rhs_arg_params.vmm_tail_idx_.clear();
314
315 Label postops_done;
316 if (mask_tail_blocked_layout) {
                // The mask_tail_blocked_layout approach of dynamic tail
                // handling is used for the blocked layout only.
                // TODO: maybe unify?
319 Label postops_no_tail;
320 mov(reg_tmp, ptr[param1 + GET_OFF(load_work)]);
321 cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block);
322 jge(postops_no_tail, T_NEAR);
323 postops_injector_->compute_vector_range(
324 vmm_idxs, rhs_arg_params_tail);
325 jmp(postops_done, T_NEAR);
326 L(postops_no_tail);
327 } else if (is_ch_tail) {
328 postops_injector_->compute_vector_range(
329 vmm_idxs, rhs_arg_params_tail);
330 }
331 if (!is_ch_tail) {
332 postops_injector_->compute_vector_range(
333 vmm_idxs, rhs_arg_params);
334 L(postops_done);
335 }
336 } else {
337 iterate(repeats, ur_ch_blocks, ur_w,
338 [&](const int r, const int ch, const int ow, const bool) {
339 vmm_idxs.emplace(get_acc_reg_idx(
340 r * ur_ch_blocks * ur_w + ch * ur_w + ow));
341 });
342 postops_injector_->compute_vector_range(vmm_idxs);
343 }
344 }
345}
346
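// Tail load/add/store helpers: the generic versions (used by avx512_core)
// rely on the k_oc_tail_mask opmask, while the avx2 and sse41 specializations
// fall back to byte-granular load_bytes/store_bytes with an explicit size.
// As an illustration (hypothetical tail of 3 floats on avx512_core), the
// masked load below behaves like
//   vmovups vmm | k_oc_tail_mask | T_z, [addr]
// with lanes 0..2 read from memory and lanes 3..15 zeroed.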
347template <cpu_isa_t isa>
348void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_tail(
349 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
350 uni_vmovups(vmm | k_oc_tail_mask | T_z, ptr[reg + offset]);
351}
352
353template <>
354void jit_uni_dw_conv_fwd_kernel_f32<avx2>::load_tail(
355 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
356 load_bytes(vmm, reg, offset, load_size);
357}
358
359template <>
360void jit_uni_dw_conv_fwd_kernel_f32<sse41>::load_tail(
361 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
362 load_bytes(vmm, reg, offset, load_size);
363}
364
365template <cpu_isa_t isa>
366void jit_uni_dw_conv_fwd_kernel_f32<isa>::add_tail_from_mem(Vmm &vmm_acc,
367 Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
368 uni_vaddps(vmm_acc | k_oc_tail_mask | T_z, vmm_acc, ptr[reg + offset]);
369}
370
371template <>
372void jit_uni_dw_conv_fwd_kernel_f32<avx2>::add_tail_from_mem(Vmm &vmm_acc,
373 Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
374 load_bytes(vmm_tmp, reg, offset, load_size);
375 uni_vaddps(vmm_acc, vmm_acc, vmm_tmp);
376}
377
378template <>
379void jit_uni_dw_conv_fwd_kernel_f32<sse41>::add_tail_from_mem(Vmm &vmm_acc,
380 Vmm &vmm_tmp, const Xbyak::Reg64 &reg, int64_t offset, int load_size) {
381 load_bytes(vmm_tmp, reg, offset, load_size);
382 uni_vaddps(vmm_acc, vmm_acc, vmm_tmp);
383}
384
385template <cpu_isa_t isa>
386void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_tail(
387 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
388 uni_vmovups(vmmword[reg + offset], vmm | k_oc_tail_mask);
389}
390
391template <>
392void jit_uni_dw_conv_fwd_kernel_f32<avx2>::store_tail(
393 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
394 store_bytes(vmm, reg, offset, store_size);
395}
396
397template <>
398void jit_uni_dw_conv_fwd_kernel_f32<sse41>::store_tail(
399 Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
400 store_bytes(vmm, reg, offset, store_size);
401}
402
403template <cpu_isa_t isa>
404void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
405 int ur_ch_blocks, int ur_w, bool is_ch_tail) {
406
407 const auto dst_layout_nxc = is_dst_layout_nxc();
408 const auto ch_blk = jcp.ch_block;
409 const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
410 const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
411 const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
412 const int c_tail = jcp.oc_without_padding % jcp.ch_block;
413
414 const int repeats = max_repeats();
415 for (int i = 0; i < repeats; i++) {
416 for (int ch = 0; ch < ur_ch_blocks; ch++) {
417 const bool is_tail_load = check_if_tail_load(
418 is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i);
419 if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen)
420 continue;
421 for (int ow = 0; ow < ur_w; ow++) {
422 const int o_off = ch * ocb_stride + ow * ow_stride + i * vlen;
423 Vmm vmm_dst
424 = get_acc_reg(i * ur_ch_blocks * ur_w + ch * ur_w + ow);
425 if (is_tail_load) {
426 store_tail(vmm_dst, reg_output, o_off * sizeof(float),
427 (c_tail - i * vlen) * sizeof(float));
428 } else
429 uni_vmovups(vmmword[reg_output + o_off * sizeof(float)],
430 vmm_dst);
431 }
432 }
433 }
434}
435
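// Emits the compute body for one ur_w-wide output strip. When more channel
// blocks are requested than jcp.nb_ch_blocking (which happens for the nxc
// layout only, see the note below), a runtime loop over groups of
// nb_ch_blocking blocks is generated together with a separate tail body for
// the remaining blocks; otherwise a single compute body suffices.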
436template <cpu_isa_t isa>
437void jit_uni_dw_conv_fwd_kernel_f32<isa>::compute_loop(
438 int ur_w, int ur_ch_blocks, int pad_l, int pad_r) {
439
440 const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;
    // ch_loop currently happens only when the data layout is nxc. The strides
    // are calculated for this layout only.
443 const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw
444 * jcp.ch_block * jcp.typesize_in;
445 const size_t inp_ch_stride
446 = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_in;
447 const size_t out_ch_stride
448 = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_out;
449 const size_t bias_stride
450 = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float);
451
452 auto compute = [&](int ur_ch_blocks, bool is_ch_tail) {
453 if (jcp.is_fused_conv) {
454 mov(aux_reg_input_buffer_ptr, reg_input_buffer_ptr);
455 } else {
456 mov(aux_reg_input, reg_input);
457 }
458
459 mov(aux_reg_kernel, reg_kernel);
460 load_src(ur_ch_blocks, ur_w, is_ch_tail);
461 apply_filter_unrolled(ur_ch_blocks, ur_w, pad_l, pad_r, is_ch_tail);
462 apply_postops(ur_ch_blocks, ur_w, is_ch_tail);
463 store_dst(ur_ch_blocks, ur_w, is_ch_tail);
464 };
465
466 mov(aux_reg_ch_blocks, reg_ch_blocks);
467 if (ch_loop) {
468 Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
469 const int ch_block_tail = jcp.nb_ch
470 - (utils::rnd_dn(jcp.oc / jcp.ch_block, jcp.nb_ch_blocking));
471 const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;
472
473 push(reg_kernel);
474 push(reg_input);
475 push(reg_output);
476 if (jcp.with_bias) push(reg_bias);
477
478 if ((jcp.oc / jcp.ch_block) >= jcp.nb_ch_blocking) {
479 if (ch_block_tail) {
480 cmp(aux_reg_ch_blocks, ch_step);
481 jl(ch_tail_label, T_NEAR);
482 }
483
484 L(ch_loop_label);
485 {
486 compute(jcp.nb_ch_blocking, false);
487 add(reg_kernel, wei_ch_stride);
488 add(reg_input, inp_ch_stride);
489 add(reg_output, out_ch_stride);
490 if (jcp.with_bias) add(reg_bias, bias_stride);
491 sub(aux_reg_ch_blocks, ch_step);
492 cmp(aux_reg_ch_blocks, ch_step);
493 jge(ch_loop_label, T_NEAR);
494 }
495 }
496
497 if (ch_block_tail) {
498 // ch work range [1, jcp.nb_ch_blocking * ch_block)
499 L(ch_tail_label);
500 cmp(aux_reg_ch_blocks, 0);
501 jle(skip_ch_tail_label, T_NEAR);
502 compute(ch_block_tail, jcp.oc % jcp.ch_block);
503 L(skip_ch_tail_label);
504 }
505
506 if (jcp.with_bias) pop(reg_bias);
507 pop(reg_output);
508 pop(reg_input);
509 pop(reg_kernel);
510
511 } else {
512 compute(ur_ch_blocks, jcp.oc % jcp.ch_block);
513 }
514}
515
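// Tiles the output width into at most four regions: a left-padded tile, a
// runtime loop of n_oi unpadded ur_w-wide tiles, an optional right-padded
// tile (r_pad1) and a ur_w_tail remainder. Illustrative (hypothetical) sizes:
// ow == 20, ur_w == 6, ur_w_tail == 2 with left padding only gives one padded
// tile, a two-iteration main loop and a final 2-pixel remainder.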
516template <cpu_isa_t isa>
517void jit_uni_dw_conv_fwd_kernel_f32<isa>::ow_loop(int ur_ch_blocks) {
518
519 int iw = jcp.iw;
520 int ow = jcp.ow;
521 int kw = jcp.kw;
522 int l_pad = jcp.l_pad;
523 int ur_w = jcp.ur_w;
524 int ur_w_tail = jcp.ur_w_tail;
525 int stride_w = jcp.stride_w;
526
527 const auto src_layout_nxc = is_src_layout_nxc();
528 const auto dat_c_stride = src_layout_nxc ? jcp.ngroups : jcp.ch_block;
529 size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * dat_c_stride;
530 size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride;
531
532 int inp_shift_pad
533 = jcp.typesize_in * (ur_w * stride_w - l_pad) * dat_c_stride;
534
535 int r_pad = nstl::max(0, jcp.r_pad);
536 int n_oi = ow / ur_w;
537 int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
538 calculate_extended_filter_size(kw, jcp.dilate_w));
539
540 assert(jcp.nb_ow <= 1);
541
542 if (r_pad1 > 0) n_oi--;
543 xor_(reg_oi, reg_oi);
544 if (ow == ur_w) {
545 compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad);
546 } else {
547 if (n_oi == 0) {
548 compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad1);
549 add(reg_input, inp_shift_pad);
550 add(reg_output, out_shift);
551 if (ur_w_tail != 0) {
552 compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
553 }
554 } else {
555 if (l_pad > 0) {
556 compute_loop(ur_w, ur_ch_blocks, l_pad, 0);
557 add(reg_input, inp_shift_pad);
558 add(reg_output, out_shift);
559 inc(reg_oi);
560 }
561 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
562 Label ow_loop_label;
563 L(ow_loop_label);
564 {
565 compute_loop(ur_w, ur_ch_blocks, 0, 0);
566 add(reg_input, inp_shift);
567 add(reg_output, out_shift);
568
569 inc(reg_oi);
570 cmp(reg_oi, n_oi);
571 jl(ow_loop_label, T_NEAR);
572 }
573 }
574 if (r_pad1 > 0) {
575 compute_loop(ur_w, ur_ch_blocks, 0, r_pad1);
576 add(reg_input, inp_shift);
577 add(reg_output, out_shift);
578 }
579 if (ur_w_tail != 0) {
580 compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
581 }
582 }
583 }
584}
585
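// Forward kernel entry point: loads the call arguments, prepares the
// avx512_core output-channel tail mask when jcp.oc_without_padding is not a
// multiple of the channel block, then dispatches either the nxc path (a
// single ow_loop over all channel blocks) or the blocked path (a main
// channel-block loop plus an optional tail loop).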
586template <cpu_isa_t isa>
587void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
588 this->preamble();
589
590 if (jcp.is_fused_conv) {
591 mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]);
        /* In case of fused depthwise convolution, `param.src` is not a pointer
        to the input; instead, it points to a buffer containing pointers to
        consecutive rows of input in format Cwc with blocking nb_ch_blocking.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
        Traverse the data as
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row0 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row1 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row2 ...
        */
606 xor_(reg_iw_offset, reg_iw_offset);
607 } else {
608 mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
609 }
610 mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
611 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
612 if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
613 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
614 mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]);
615
616 Label ch_blocks_tail_label;
617 Label exit_label;
618
619 int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
620 if (isa > avx2) {
621 const auto oc_tail = jcp.oc_without_padding % jcp.ch_block;
622 if (oc_tail != 0) {
            // Prepare the mask for tail processing
624 const int oc_tail_shift
625 = jcp.ch_block - jcp.oc_without_padding % jcp.ch_block;
626 static constexpr auto zmm_full_mask = ((1 << 16) - 1);
627 Reg32 reg_tail_32 = reg_tail.cvt32();
628 mov(reg_tail_32, (zmm_full_mask >> oc_tail_shift));
629 kmovw(k_oc_tail_mask, reg_tail_32);
630 }
631 }
632
633 if (is_src_layout_nxc()) {
634 ow_loop(jcp.nb_ch);
635 } else {
636 cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block);
637 jle(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
638
639 ow_loop(jcp.nb_ch_blocking); // channel main loop
640
641 if (ch_blocks_tail) {
642 jmp(exit_label, T_NEAR);
643 L(ch_blocks_tail_label);
644 ow_loop(ch_blocks_tail); // channel tail loop
645 }
646
647 L(exit_label);
648 }
649
650 this->postamble();
651
652 if (jcp.with_eltwise) postops_injector_->prepare_table();
653}
654
655template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_core>;
656template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
657template struct jit_uni_dw_conv_fwd_kernel_f32<sse41>;
658
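// Backward-data tail helpers mirror the forward ones: avx512_core masks with
// k_ch_tail_mask, avx2 loads/stores jcp.ch_tail elements via
// load_bytes/store_bytes, and the generic version (used by sse41) operates on
// simd_w_-sized halves, shrinking the last half to the channel remainder.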
659template <cpu_isa_t isa>
660inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_vmm(
661 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
662 int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41
663 int bytes = (tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float);
664 load_bytes(vmm, addr, bytes);
665}
666template <>
667inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx2>::load_vmm(
668 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
669 int bytes = (tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float);
670 load_bytes(vmm, addr, bytes);
671}
672template <>
673inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>::load_vmm(
674 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
675 Zmm masked_vmm = tail ? vmm | k_ch_tail_mask | T_z : vmm;
676 vmovups(masked_vmm, addr);
677}
678
679template <cpu_isa_t isa>
680inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_vmm(
681 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
682 int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41
683 int bytes = (tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float);
684 store_bytes(vmm, addr, bytes);
685}
686template <>
687inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx2>::store_vmm(
688 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
689 int bytes = (tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float);
690 store_bytes(vmm, addr, bytes);
691}
692template <>
693inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>::store_vmm(
694 Vmm &vmm, const Xbyak::Address &addr, bool tail) {
695 Zmm masked_vmm = tail ? vmm | k_ch_tail_mask : vmm;
696 vmovups(addr, masked_vmm);
697}
698
699template <cpu_isa_t isa>
700inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
701 int ur_ch_blocks, int ur_str_w) {
702 for (int i = 0; i < reg_repeats_; i++) {
703 for (int ch = 0; ch < ur_ch_blocks; ch++) {
704 for (int w = 0; w < ur_str_w; w++) {
705 Vmm vmm_acc = get_acc_reg(
706 i * ur_ch_blocks * ur_str_w + ch * ur_str_w + w);
707 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
708 }
709 }
710 }
711}
712
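// Accumulates diff_src for ur_str_w strided input positions. Both kh and kw
// are runtime loops that step by stride_h/stride_w, since for strides larger
// than one only every stride-th filter tap contributes to a given input
// pixel; reg_kh/reg_kw hold the padding-clipped amounts passed in through
// kh_padding/kw_padding.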
713template <cpu_isa_t isa>
714inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
715 int ur_ch_blocks, int ur_str_w, bool is_last_ch) {
716 int kw = jcp.kw;
717 int kh = jcp.kh;
718 int ow = jcp.ow;
719 int oh = jcp.oh;
720
721 int ch_blk = jcp.ch_block;
722 int stride_h = jcp.stride_h;
723 int stride_w = jcp.stride_w;
724
725 const bool ddst_layout_nxc = is_ddst_layout_nxc();
726 const size_t ch_block_step = ch_blk * (ddst_layout_nxc ? 1 : oh * ow);
727 const size_t sp_step = ddst_layout_nxc ? jcp.ngroups : ch_blk;
728
729 Label iter_exit_label;
730
731 cmp(reg_kh, 0);
732 je(iter_exit_label, T_NEAR);
733
734 cmp(reg_kw, 0);
735 je(iter_exit_label, T_NEAR);
736
737 mov(iter_kh, reg_kh);
738 Label kh_label;
739 L(kh_label);
740 {
741 mov(aux1_reg_ddst, aux_reg_ddst);
742 mov(aux1_reg_kernel, aux_reg_kernel);
743
744 mov(iter_kw, reg_kw);
745 Label kw_label;
746 L(kw_label);
747 {
748 for (int r = 0; r < reg_repeats_; r++) {
749 for (int ch = 0; ch < ur_ch_blocks; ch++) {
750 bool last_block = is_last_ch && ch == ur_ch_blocks - 1;
751 bool masked_load = last_block
752 && IMPLICATION(
753 isa == sse41, tail_simd_overlap(r + 1));
754
755 // sse41: if second simd_w is outside channel_block, skip
756 if (last_block && isa == sse41 && tail_simd_overlap(r))
757 break;
758
759 int ker_off = ch * kh * kw * ch_blk + r * simd_w_;
760 Vmm vmm_ker = get_ker_reg(0);
761 load_vmm(vmm_ker,
762 ptr[aux1_reg_kernel + ker_off * sizeof(float)],
763 masked_load);
764
765 for (int w = 0; w < ur_str_w; w++) {
766 size_t sp_offset = w * sp_step;
767 size_t ch_offset = ch * ch_block_step;
768 size_t ddst_off = static_cast<size_t>(
769 (sp_offset + ch_offset + r * simd_w_)
770 * sizeof(float));
771
772 Vmm vmm_ddst = get_ddst_reg(0);
773 load_vmm(vmm_ddst, ptr[aux1_reg_ddst + ddst_off],
774 masked_load);
775
776 Vmm vmm_acc = get_acc_reg(r * ur_ch_blocks * ur_str_w
777 + ch * ur_str_w + w);
778 uni_vfmadd231ps(vmm_acc, vmm_ddst, vmm_ker);
779 }
780 }
781 }
782
783 add(aux1_reg_kernel, ch_blk * stride_w * sizeof(float));
784 sub(aux1_reg_ddst, sp_step * sizeof(float));
785
786 sub(iter_kw, stride_w);
787 cmp(iter_kw, 0);
788 jg(kw_label, T_NEAR);
789 }
790
791 add(aux_reg_kernel, kw * ch_blk * stride_h * sizeof(float));
792 sub(aux_reg_ddst, ow * sp_step * sizeof(float));
793
794 sub(iter_kh, stride_h);
795 cmp(iter_kh, 0);
796 jg(kh_label, T_NEAR);
797 }
798
799 L(iter_exit_label);
800}
801
802template <cpu_isa_t isa>
803inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
804 int ur_ch_blocks, int ur_str_w, bool is_last_ch) {
805 int ch_block = jcp.ch_block;
806 int iw = jcp.iw;
807 int ih = jcp.ih;
808 int stride_w = jcp.stride_w;
809
810 const auto dsrc_layout_nxc = is_dsrc_layout_nxc();
811 const size_t ch_block_step = ch_block * (dsrc_layout_nxc ? 1 : ih * iw);
812 const size_t sp_step
813 = dsrc_layout_nxc ? jcp.ngroups : ch_block; // spatial step
814
815 for (int r = 0; r < reg_repeats_; r++) {
816 for (int ch = 0; ch < ur_ch_blocks; ch++) {
817 bool last_block = is_last_ch && ch == ur_ch_blocks - 1;
818 bool masked_store = last_block
819 && IMPLICATION(isa == sse41, tail_simd_overlap(r + 1));
820
821 // sse41: if second simd_w is outside channel_block, skip
822 if (last_block && tail_simd_overlap(r)) break;
823
824 for (int w = 0; w < ur_str_w; w++) {
825 size_t sp_offset = w * stride_w * sp_step;
826 size_t ch_offset = ch * ch_block_step + r * simd_w_;
827 size_t dsrc_off = static_cast<size_t>(
828 (sp_offset + ch_offset) * sizeof(float));
829
830 Vmm vmm_acc
831 = get_acc_reg((r * ur_ch_blocks + ch) * ur_str_w + w);
832 store_vmm(vmm_acc, ptr[reg_dsrc + dsrc_off], masked_store);
833 }
834 }
835 }
836}
837
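// Channel-block loop of the backward-data kernel, structured like the forward
// compute_loop: a runtime loop over nb_ch_blocking-sized groups (nxc layout
// only) plus a masked tail body when jcp.ch_tail > 0.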
838template <cpu_isa_t isa>
839inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::ch_loop_body(
840 int ur_ch_blocks, int unroll_w) {
841
842 auto call_compute_body
843 = [&](int ur_ch_blocks, int unroll_w, bool is_last_ch = false) {
844 mov(aux_reg_ddst, reg_ddst);
845 mov(aux_reg_kernel, reg_kernel);
846
847 load_ddst(ur_ch_blocks, unroll_w);
848 apply_filter(ur_ch_blocks, unroll_w, is_last_ch);
849 store_dsrc(ur_ch_blocks, unroll_w, is_last_ch);
850 };
851
852 const bool write_ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;
853 if (write_ch_loop) {
854 assert(is_ddst_layout_nxc());
855
856 Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
857 const int nb_oc = jcp.oc / jcp.ch_block;
858 const int ch_block_tail
859 = jcp.nb_ch - (utils::rnd_dn(nb_oc, jcp.nb_ch_blocking));
860 const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;
861
862 const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh
863 * jcp.kw * jcp.ch_block * sizeof(float);
864 const size_t data_ch_stride
865 = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float);
866
867 mov(aux_reg_ch_blocks, reg_ch_blocks);
868 push(reg_dsrc);
869 push(reg_ddst);
870 push(reg_kernel);
871
872 if (nb_oc >= jcp.nb_ch_blocking) {
873 if (ch_block_tail) {
874 cmp(aux_reg_ch_blocks, jcp.nb_ch_blocking * jcp.ch_block);
875 jl(ch_tail_label, T_NEAR);
876 }
877
878 L(ch_loop_label);
879 {
880 call_compute_body(jcp.nb_ch_blocking, unroll_w);
881
882 add(reg_kernel, wei_ch_stride);
883 add(reg_dsrc, data_ch_stride);
884 add(reg_ddst, data_ch_stride);
885
886 sub(aux_reg_ch_blocks, ch_step);
887 cmp(aux_reg_ch_blocks, ch_step);
888 jge(ch_loop_label, T_NEAR);
889 }
890 }
891
892 if (ch_block_tail) {
893 // ch work range [1, jcp.nb_ch_blocking * ch_block)
894 L(ch_tail_label);
895 cmp(aux_reg_ch_blocks, 0);
896 jle(skip_ch_tail_label, T_NEAR);
897 call_compute_body(ch_block_tail, unroll_w, jcp.ch_tail > 0);
898 L(skip_ch_tail_label);
899 }
900
901 pop(reg_kernel);
902 pop(reg_ddst);
903 pop(reg_dsrc);
904
905 } else {
906 call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail > 0);
907 }
908}
909
910template <cpu_isa_t isa>
911inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::unroll_width_body(
912 int ur_ch_blocks) {
913 assert(is_dsrc_layout_nxc() == is_ddst_layout_nxc());
914 const size_t ch_step = sizeof(float)
915 * (is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block);
916
917 auto unroll_width_loop = [&](int unroll_w) {
918 Label unroll_w_label, skip_compute_label;
919 L(unroll_w_label);
920 {
921 cmp(reg_ur_str_w, unroll_w);
922 jl(skip_compute_label, T_NEAR);
923
924 ch_loop_body(ur_ch_blocks, unroll_w);
925
926 add(reg_dsrc, unroll_w * jcp.stride_w * ch_step);
927 add(reg_ddst, unroll_w * ch_step);
928
929 sub(reg_ur_str_w, unroll_w);
930 jmp(unroll_w_label);
931 }
932 L(skip_compute_label);
933 };
934
935 unroll_width_loop(jcp.ur_w);
936
937 unroll_width_loop(1);
938}
939
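// Backward-data entry point. For the nxc layout on avx512_core the channel
// tail mask is first set to all ones and only narrowed to
// (1 << jcp.ch_tail) - 1 when the runtime channel count shows that this call
// handles the tail.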
940template <cpu_isa_t isa>
941void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
942 preamble();
943
944 mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
945 mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
946 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
947 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
948 mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
949 mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
950 mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
951
952 if (is_dsrc_layout_nxc()) {
953 if (isa == avx512_core && (jcp.ch_tail > 0)) {
954 Label masking_done;
955 const size_t channel_step = jcp.nb_ch_blocking * jcp.ch_block;
956 kxnorw(k_ch_tail_mask, k_ch_tail_mask,
957 k_ch_tail_mask); // dummy mask all 1's
958 cmp(reg_ch_blocks, channel_step);
959 je(masking_done, T_NEAR);
960 // Prepare masks for tail
961 Reg32 reg_tmp_32 = reg_tmp.cvt32();
962 mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
963 kmovw(k_ch_tail_mask, reg_tmp_32);
964 L(masking_done);
965 }
966
967 unroll_width_body(jcp.nb_ch);
968 } else {
969
970 auto ch_blocks_loop = [&](int ch_blocks) {
971 Label skip_loop_label;
972 cmp(reg_ch_blocks, ch_blocks * jcp.ch_block);
973 jl(skip_loop_label, T_NEAR);
974 unroll_width_body(ch_blocks);
975 L(skip_loop_label);
976 };
977
978 ch_blocks_loop(jcp.nb_ch_blocking);
979
980 int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
981 if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); }
982 }
983
984 this->postamble();
985}
986#undef GET_OFF
987
988template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>;
989template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
990template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse41>;
991
992#define GET_OFF(field) offsetof(jit_dw_conv_call_s, field)
993
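// The backward-weights helpers follow the same tail-handling scheme as above:
// avx512_core masks with k_ch_tail_mask, avx2 uses jcp.ch_tail-element byte
// loads/stores, and the generic version (used by sse41) works in simd_w_
// halves derived from jcp.oc_without_padding.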
994template <cpu_isa_t isa>
995inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_xmm(
996 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
997 int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41
998 int bytes
999 = (compute_tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float);
1000 load_bytes(vmm, addr, bytes);
1001}
1002template <>
1003inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::load_xmm(
1004 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
1005 int bytes = (compute_tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float);
1006 load_bytes(vmm, addr, bytes);
1007}
1008template <>
1009inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::load_xmm(
1010 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
1011 Zmm masked_vmm = compute_tail ? vmm | k_ch_tail_mask | T_z : vmm;
1012 vmovups(masked_vmm, addr);
1013}
1014
1015template <cpu_isa_t isa>
1016inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_xmm(
1017 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
1018 int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41
1019 int bytes
1020 = (compute_tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float);
1021 store_bytes(vmm, addr, bytes);
1022}
1023template <>
1024inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::store_xmm(
1025 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
1026 int bytes = (compute_tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float);
1027 store_bytes(vmm, addr, bytes);
1028}
1029template <>
1030inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::store_xmm(
1031 Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) {
1032 Zmm masked_vmm = compute_tail ? vmm | k_ch_tail_mask : vmm;
1033 vmovups(addr, masked_vmm);
1034}
1035
1036template <cpu_isa_t isa>
1037inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::addps_xmm(Vmm &vmm_dst,
1038 Vmm &vmm_src, const Xbyak::Address &addr, bool compute_tail) {
1039 load_xmm(vmm_src, addr, compute_tail);
1040 uni_vaddps(vmm_dst, vmm_dst, vmm_src);
1041}
1042template <>
1043inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::addps_xmm(
1044 Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr,
1045 bool compute_tail) {
1046 if (compute_tail) {
1047 load_xmm(vmm_src, addr, true);
1048 uni_vaddps(vmm_dst, vmm_dst, vmm_src);
1049 } else {
1050 assert(vmm_dst.getIdx() == vmm_src.getIdx());
1051 uni_vaddps(vmm_dst, vmm_src, addr);
1052 }
1053}
1054template <>
1055inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::addps_xmm(
1056 Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr,
1057 bool compute_tail) {
1058 Zmm masked_vmm = compute_tail ? vmm_src | k_ch_tail_mask | T_z : vmm_src;
1059 vaddps(vmm_dst, masked_vmm, addr);
1060}
1061
1062template <cpu_isa_t isa>
1063inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
1064 for (int ch = 0; ch < jcp.nb_ch_blocking; ++ch) {
1065 for (int r = 0; r < reg_repeats_; ++r) {
1066 for (int i = 0; i < jcp.kw; ++i) {
1067 Vmm vmm_acc
1068 = get_acc_reg(r * jcp.kw + i * jcp.nb_ch_blocking + ch);
1069 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
1070 }
1071 }
1072 }
1073}
1074
1075template <>
1076inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::load_filter(
1077 int nb_ch_blocking, bool is_last_ch) {
1078 assert(nb_ch_blocking == 1);
1079 for (int r = 0; r < reg_repeats_; ++r) {
1080 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1081 bool masked_load = tail_in_first_simd && is_last_ch;
1082 const int reg_set = r * jcp.kw;
1083 for (int i = 0; i < jcp.kw; ++i) {
1084 size_t off_filter = static_cast<size_t>(
1085 (i * jcp.ch_block + r * simd_w_) * sizeof(float));
1086 Vmm vmm_acc = get_acc_reg(reg_set + i);
1087 load_xmm(
1088 vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load);
1089 }
1090 if (masked_load) break; // if tail falls under first simd, skip
1091 }
1092}
1093
1094template <cpu_isa_t isa>
1095inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter(
1096 int nb_ch_blocking, bool is_last_ch) {
1097 const size_t filter_step = jcp.kh * jcp.kw;
1098 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1099 bool masked_load = is_last_ch && (ch == nb_ch_blocking - 1);
1100 for (int i = 0; i < jcp.kw; ++i) {
1101 size_t off_filter = static_cast<size_t>(
1102 (ch * filter_step + i) * jcp.ch_block * sizeof(float));
1103 Vmm vmm_acc = get_acc_reg(i * jcp.nb_ch_blocking + ch);
1104 load_xmm(
1105 vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load);
1106 }
1107 }
1108}
1109
1110template <cpu_isa_t isa>
1111inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() {
1112 for (int ch = 0; ch < jcp.nb_ch_blocking; ++ch) {
1113 for (int r = 0; r < reg_repeats_; ++r) {
1114 Vmm vmm_bias = get_bias_reg(r * jcp.nb_ch_blocking + ch);
1115 uni_vpxor(vmm_bias, vmm_bias, vmm_bias);
1116 }
1117 }
1118}
1119
1120template <>
1121inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::load_bias(
1122 int nb_ch_blocking, bool is_last_ch) {
1123 for (int r = 0; r < reg_repeats_; ++r) {
1124 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1125 bool masked_load = tail_in_first_simd && is_last_ch;
1126 size_t half_ch_block_offset
1127 = static_cast<size_t>(r * simd_w_ * sizeof(float));
1128 Vmm vmm_bias = get_bias_reg(r);
1129 load_xmm(vmm_bias, vmmword[reg_bias_baddr + half_ch_block_offset],
1130 masked_load);
1131 if (masked_load) break; // if tail falls under first simd, skip
1132 }
1133}
1134
1135template <cpu_isa_t isa>
1136inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias(
1137 int nb_ch_blocking, bool is_last_ch) {
1138 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1139 bool masked_load = is_last_ch && (ch == nb_ch_blocking - 1);
1140 size_t bias_offset
1141 = static_cast<size_t>(ch * jcp.ch_block * sizeof(float));
1142 Vmm vmm_bias = get_bias_reg(ch);
1143 load_xmm(vmm_bias, vmmword[reg_bias_baddr + bias_offset], masked_load);
1144 }
1145}
1146
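// Weight-gradient inner kernel for the is_fast_depthwise (nxc,
// avx2/avx512_core) path. Output positions are processed left to right while
// the input registers form a rotating window of jcp.kw entries, so each new
// position only "cascades" in min(stride_w, kw) fresh input loads and reuses
// the rest from the previous position.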
1147template <cpu_isa_t isa>
1148inline void
1149jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_unroll_ow_step_nxc(
1150 int unroll_w, int l_pad, int pad_offset, int ow_block,
1151 int nb_ch_blocking, bool is_last_ch) {
1152
1153 assert(one_of(isa, avx2, avx512_core));
1154
1155 const size_t ch_step = jcp.ngroups;
1156 const int iw_block = ow_block * jcp.stride_w;
1157 const int right_border = jcp.iw - iw_block;
1158 const int r_pad = jcp.r_pad;
1159 const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
1160
    /* preamble count for the number of cascaded LOAD + FMA operations */
1162 const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
1163 const bool is_last_block = (unroll_w + ow_block == jcp.ow);
1164
    /* LOAD initial input registers, then cascade LOADs and FMAs */
1166 for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
1167 int output_sp_offset = i_ur * ch_step;
1168 if (i_ur == 0) {
1169 for (int c = 0; c < input_overlap; ++c) {
1170 int input_sp = c - pad_offset;
1171 int input_sp_offset = input_sp * ch_step;
1172 if (input_sp_offset < 0 && unroll_w == jcp.ow) continue;
1173
1174 const bool over_steps_bdry = true && is_last_block
1175 && (c - pad_offset + r_pad > right_border);
1176 if (over_steps_bdry) continue;
1177
1178 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1179 bool masked_load = is_last_ch && ch == nb_ch_blocking - 1;
1180 size_t input_offset = static_cast<size_t>(
1181 (input_sp_offset + ch * simd_w_) * sizeof(float));
1182 Vmm vmm_input = get_input_reg(
1183 (c % jcp.kw) * jcp.nb_ch_blocking + ch);
1184 load_xmm(vmm_input, ptr[reg_tmp_input + input_offset],
1185 masked_load);
1186 }
1187 }
1188 } else {
1189 for (int c = 0; c < cascade_input; ++c) {
1190 int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
1191 int input_sp = overlap + c - pad_offset;
1192 int input_sp_offset = input_sp * ch_step;
1193 if (input_sp_offset < 0 || overlap + c + l_pad > right_border)
1194 continue;
1195
1196 const bool over_steps_bdry = true && is_last_block
1197 && (overlap + c - pad_offset + r_pad > right_border);
1198 if (over_steps_bdry) continue;
1199
1200 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1201 bool masked_load = is_last_ch && ch == nb_ch_blocking - 1;
1202 size_t input_offset = static_cast<size_t>(
1203 (input_sp_offset + ch * simd_w_) * sizeof(float));
1204 Vmm vmm_input = get_input_reg(
1205 ((overlap + c) % jcp.kw) * jcp.nb_ch_blocking + ch);
1206 load_xmm(vmm_input, ptr[reg_tmp_input + input_offset],
1207 masked_load);
1208 }
1209 }
1210 }
1211 for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
1212 int io_overlap = i_kw + (i_ur * jcp.stride_w);
1213
1214 /* Don't apply FMAs that fall into the padded region */
1215 if (io_overlap - l_pad < 0
1216 || io_overlap - jcp.l_pad >= right_border)
1217 continue;
1218
1219 const bool over_steps_bdry = is_last_block
1220 && (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
1221 if (over_steps_bdry) continue;
1222
1223 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1224 bool masked_load = is_last_ch && ch == nb_ch_blocking - 1;
1225 size_t output_offset = static_cast<size_t>(
1226 (output_sp_offset + ch * simd_w_) * sizeof(float));
1227
1228 Vmm vmm_input = get_input_reg(
1229 ((io_overlap - l_pad) % jcp.kw) * jcp.nb_ch_blocking
1230 + ch);
1231 Vmm vmm_acc = get_acc_reg(i_kw * jcp.nb_ch_blocking + ch);
1232 if (masked_load) {
1233 Vmm vmm_output = get_output_reg(0);
1234 load_xmm(vmm_output, ptr[reg_tmp_output + output_offset],
1235 true);
1236 uni_vfmadd231ps(vmm_acc, vmm_input, vmm_output);
1237 } else {
1238 uni_vfmadd231ps(vmm_acc, vmm_input,
1239 ptr[reg_tmp_output + output_offset]);
1240 }
1241 }
1242 }
1243 }
1244}
1245
1246template <cpu_isa_t isa>
1247inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_unroll_ow_step(
1248 int unroll_w, int l_pad, int pad_offset, int ow_block,
1249 bool is_last_ch) {
1250
1251 const size_t ch_step = is_layout_nxc() ? jcp.ngroups : simd_w_;
1252 const int iw_block = ow_block * jcp.stride_w;
1253 const int right_border = jcp.iw - iw_block;
1254 const int r_pad = jcp.r_pad;
1255 const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
1256
    /* preamble count for the number of cascaded LOAD + FMA operations */
1258 const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
1259 const bool is_last_block = (unroll_w + ow_block == jcp.ow);
1260 const bool nxc_sse41_offset = is_layout_nxc() && isa == sse41;
1261
    /* LOAD initial input registers, then cascade LOADs and FMAs */
1263 for (int r = 0; r < reg_repeats_; ++r) {
1264 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1265 bool masked_load
1266 = IMPLICATION(isa == sse41, tail_in_first_simd) && is_last_ch;
1267 for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
1268 int output_sp_offset = nxc_sse41_offset
1269 ? i_ur * ch_step + r * simd_w_
1270 : (i_ur * reg_repeats_ + r) * ch_step;
1271 size_t output_offset
1272 = static_cast<size_t>(output_sp_offset * sizeof(float));
1273 Vmm vmm_output = get_output_reg(r);
1274 load_xmm(vmm_output, ptr[reg_tmp_output + output_offset],
1275 masked_load);
1276 if (i_ur == 0) {
1277 for (int c = 0; c < input_overlap; ++c) {
1278 int input_sp = c - pad_offset;
1279 int input_sp_offset = nxc_sse41_offset
1280 ? input_sp * ch_step + r * simd_w_
1281 : (input_sp * reg_repeats_ + r) * ch_step;
1282 if (input_sp_offset < 0 && unroll_w == jcp.ow) continue;
1283
1284 const bool over_steps_bdry = true && is_last_block
1285 && (c - pad_offset + r_pad > right_border);
1286 if (over_steps_bdry) continue;
1287
1288 size_t input_offset = static_cast<size_t>(
1289 input_sp_offset * sizeof(float));
1290 Vmm vmm_input
1291 = get_input_reg((c % jcp.kw) * reg_repeats_ + r);
1292 load_xmm(vmm_input, ptr[reg_tmp_input + input_offset],
1293 masked_load);
1294 }
1295 } else {
1296 for (int c = 0; c < cascade_input; ++c) {
1297 int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
1298 int input_sp = overlap + c - pad_offset;
1299 int input_sp_offset = nxc_sse41_offset
1300 ? input_sp * ch_step + r * simd_w_
1301 : (input_sp * reg_repeats_ + r) * ch_step;
1302 if (input_sp_offset < 0
1303 || overlap + c + l_pad > right_border)
1304 continue;
1305
1306 const bool over_steps_bdry = true && is_last_block
1307 && (overlap + c - pad_offset + r_pad
1308 > right_border);
1309 if (over_steps_bdry) continue;
1310
1311 size_t input_offset = static_cast<size_t>(
1312 input_sp_offset * sizeof(float));
1313 Vmm vmm_input = get_input_reg(
1314 ((overlap + c) % jcp.kw) * reg_repeats_ + r);
1315 load_xmm(vmm_input, ptr[reg_tmp_input + input_offset],
1316 masked_load);
1317 }
1318 }
1319 for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
1320 int io_overlap = i_kw + (i_ur * jcp.stride_w);
1321
1322 /* Don't apply FMAs that fall into the padded region */
1323 if (io_overlap - l_pad < 0
1324 || io_overlap - jcp.l_pad >= right_border)
1325 continue;
1326
1327 const bool over_steps_bdry = is_last_block
1328 && (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
1329 if (over_steps_bdry) continue;
1330
1331 Vmm vmm_input = get_input_reg(
1332 ((io_overlap - l_pad) % jcp.kw) * reg_repeats_ + r);
1333 Vmm vmm_acc = get_acc_reg(r * jcp.kw + i_kw);
1334 Vmm vmm_aux = isa == sse41 ? get_aux_reg() : vmm_input;
1335 if (isa == sse41) uni_vmovups(vmm_aux, vmm_input);
1336 uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
1337 }
1338 }
1339 if (isa == sse41 && masked_load)
1340 break; // if tail falls under first simd, skip
1341 }
1342}
1343
1344template <cpu_isa_t isa>
1345inline void
1346jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::dispatch_ow_step_unroll(
1347 int unroll_w, int l_pad, int pad_offset, int ow_block,
1348 int nb_ch_blocking, bool is_last_ch) {
1349 if (jcp.is_fast_depthwise) {
1350 compute_unroll_ow_step_nxc(unroll_w, l_pad, pad_offset, ow_block,
1351 nb_ch_blocking, is_last_ch);
1352 } else {
1353 assert(nb_ch_blocking == 1);
1354 compute_unroll_ow_step(
1355 unroll_w, l_pad, pad_offset, ow_block, is_last_ch);
1356 }
1357}
1358
1359template <>
1360inline void
1361jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::compute_bias_step_unroll(
1362 const int unroll_w, int nb_ch_blocking, bool is_last_ch) {
1363 const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : simd_w_;
1364 for (int r = 0; r < reg_repeats_; ++r) {
1365 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1366 bool masked_load = tail_in_first_simd && is_last_ch;
1367 for (int i = 0; i < unroll_w; ++i) {
1368 int off_output = is_ddst_layout_nxc()
1369 ? i * ch_step + r * simd_w_
1370 : (i * reg_repeats_ + r) * ch_step;
1371 Vmm vmm_bias = get_bias_reg(r);
1372 Vmm vmm_out = get_output_reg(1 + r);
1373 addps_xmm(vmm_bias, vmm_out,
1374 vmmword[reg_tmp_output + off_output * sizeof(float)],
1375 masked_load);
1376 }
1377 if (masked_load) break; // if tail falls under first simd, skip
1378 }
1379}
1380
1381template <cpu_isa_t isa>
1382inline void
1383jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
1384 const int unroll_w, int nb_ch_blocking, bool is_last_ch) {
1385 const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : simd_w_;
1386 for (int i = 0; i < unroll_w; ++i) {
1387 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1388 Vmm vmm_bias = get_bias_reg(ch);
1389 size_t off_output = static_cast<size_t>(
1390 (i * ch_step + ch * simd_w_) * sizeof(float));
1391 bool masked_store = is_last_ch && (ch == nb_ch_blocking - 1);
1392 bool use_extra_vmm = isa == avx2 && masked_store;
1393 Vmm vmm_out = use_extra_vmm ? get_output_reg(1) : vmm_bias;
1394 addps_xmm(vmm_bias, vmm_out, vmmword[reg_tmp_output + off_output],
1395 masked_store);
1396 }
1397 }
1398}
1399
1400template <>
1401inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::store_filter(
1402 int nb_ch_blocking, bool is_last_ch) {
1403 assert(nb_ch_blocking == 1);
1404 for (int r = 0; r < reg_repeats_; ++r) {
1405 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1406 bool masked_load = tail_in_first_simd && is_last_ch;
1407 const int reg_set = r * jcp.kw;
1408 for (int i = 0; i < jcp.kw; ++i) {
1409 size_t off_filter = static_cast<size_t>(
1410 (i * jcp.ch_block + r * simd_w_) * sizeof(float));
1411 Vmm vmm_acc = get_acc_reg(i + reg_set);
1412 store_xmm(
1413 vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load);
1414 }
1415 if (masked_load) break; // if tail falls under first simd, skip
1416 }
1417}
1418
1419template <cpu_isa_t isa>
1420inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter(
1421 int nb_ch_blocking, bool is_last_ch) {
1422 size_t filter_step = jcp.kh * jcp.kw;
1423 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1424 bool masked_store = is_last_ch && ch == nb_ch_blocking - 1;
1425 for (int i = 0; i < jcp.kw; ++i) {
1426 size_t off_filter = static_cast<size_t>(
1427 (ch * filter_step + i) * jcp.ch_block * sizeof(float));
1428 Vmm vmm_acc = get_acc_reg(i * jcp.nb_ch_blocking + ch);
1429 store_xmm(vmm_acc, vmmword[reg_tmp_filter + off_filter],
1430 masked_store);
1431 }
1432 }
1433}
1434
1435template <>
1436inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::store_bias(
1437 int nb_ch_blocking, bool is_last_ch) {
1438 for (int r = 0; r < reg_repeats_; ++r) {
1439 bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail;
1440 bool masked_load = tail_in_first_simd && is_last_ch;
1441 size_t half_ch_block_offset
1442 = static_cast<size_t>(r * simd_w_ * sizeof(float));
1443 Vmm vmm_bias = get_bias_reg(r);
1444 store_xmm(vmm_bias, vmmword[reg_bias_baddr + half_ch_block_offset],
1445 masked_load);
1446 if (masked_load) break; // if tail falls under first simd, skip
1447 }
1448}
1449
1450template <cpu_isa_t isa>
1451inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias(
1452 int nb_ch_blocking, bool is_last_ch) {
1453 for (int ch = 0; ch < nb_ch_blocking; ++ch) {
1454 bool masked_store = is_last_ch && ch == nb_ch_blocking - 1;
1455 size_t bias_offset = static_cast<size_t>(ch * simd_w_ * sizeof(float));
1456 Vmm vmm_bias = get_bias_reg(ch);
1457 store_xmm(
1458 vmm_bias, vmmword[reg_bias_baddr + bias_offset], masked_store);
1459 }
1460}
1461
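// Accumulates the bias gradient over the assigned oh range
// ([oh_index, oh_count)) for the current channel blocks; the output width is
// processed in unroll_w-sized chunks (at most max_unroll_w_ pixels) plus an
// optional shorter remainder.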
1462template <cpu_isa_t isa>
1463inline void
1464jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_spatial_loop_bias(
1465 int nb_ch_blocking, bool is_last_ch) {
1466 Label oh_label;
1467 Label ow_blk_label;
1468
1469 const int unroll_w = nstl::min(max_unroll_w_, jcp.ow);
1470 const int unroll_w_trips = jcp.ow / unroll_w;
1471 const int tail_w = jcp.ow > max_unroll_w_ ? jcp.ow % max_unroll_w_ : 0;
1472
1473 const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
1474 const size_t ch_offset = ch_step * sizeof(float);
1475
1476 mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
1477 mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);
1478
1479 mov(reg_tmp_output, reg_output_baddr);
1480 L(oh_label);
1481 {
1482
1483 mov(reg_iter_ow_blk, unroll_w_trips);
1484 L(ow_blk_label);
1485 {
1486 compute_bias_step_unroll(unroll_w, nb_ch_blocking, is_last_ch);
1487 add(reg_tmp_output, unroll_w * ch_offset);
1488
1489 dec(reg_iter_ow_blk);
1490 cmp(reg_iter_ow_blk, 0);
1491 jg(ow_blk_label, T_NEAR);
1492 }
1493
1494 if (tail_w > 0) {
1495 compute_bias_step_unroll(tail_w, nb_ch_blocking, is_last_ch);
1496 add(reg_tmp_output, tail_w * ch_offset);
1497 }
1498
1499 inc(reg_oh);
1500 cmp(reg_oh, reg_oh_worksize);
1501 jl(oh_label, T_NEAR);
1502 }
1503}
1504
1505template <cpu_isa_t isa>
1506void jit_uni_dw_conv_bwd_weights_kernel_f32<
1507 isa>::compute_single_ch_block_bias() {
1508
1509 auto write_compute_bias = [&](bool is_last_ch) {
1510 Label skip_load_bias;
1511
1512 mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
1513 and_(reg_exec_flags, FLAG_ZERO_BIAS);
1514 test(reg_exec_flags, reg_exec_flags);
1515 jne(skip_load_bias);
1516
1517 assert(jcp.nb_ch_blocking == 1);
1518 load_bias(jcp.nb_ch_blocking, is_last_ch);
1519
1520 L(skip_load_bias);
1521 compute_spatial_loop_bias(jcp.nb_ch_blocking, is_last_ch);
1522
1523 store_bias(jcp.nb_ch_blocking, is_last_ch);
1524 };
1525
1526 Label skip_masked_bias_label, done_bias_label;
1527
1528 zero_bias();
1529
1530 bool do_bias_ch_tail = jcp.ch_tail > 0;
1531 if (do_bias_ch_tail) {
1532 // test last channel
1533 mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
1534 and_(reg_exec_flags, FLAG_OC_LAST);
1535 test(reg_exec_flags, reg_exec_flags);
1536 jz(skip_masked_bias_label, T_NEAR);
1537
1538 write_compute_bias(true);
1539
1540 jmp(done_bias_label, T_NEAR);
1541 L(skip_masked_bias_label);
1542 }
1543
1544 write_compute_bias(false);
1545
1546 L(done_bias_label);
1547}
1548
1549template <cpu_isa_t isa>
1550void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ch_loop_bias(
1551 bool do_load_bias) {
1552
1553 assert(is_ddst_layout_nxc());
1554
1555 auto write_compute_bias = [&](int nb_ch_blocking, bool is_last_ch) {
1556 if (do_load_bias)
1557 load_bias(nb_ch_blocking, is_last_ch);
1558 else
1559 zero_bias();
1560 compute_spatial_loop_bias(nb_ch_blocking, is_last_ch);
1561 store_bias(nb_ch_blocking, is_last_ch);
1562 };
1563
1564 if (jcp.nb_ch > jcp.nb_ch_blocking) {
1565
1566 Label ch_loop_label;
1567 const bool masked_ch_tail = jcp.ch_tail > 0;
1568 const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking;
1569 const bool unroll_last_ch_block
1570 = nb_ch_blocking_tail > 0 || masked_ch_tail;
1571 const int last_ch_block = nb_ch_blocking_tail > 0 ? nb_ch_blocking_tail
1572 : jcp.nb_ch_blocking;
1573
1574 push(reg_output_baddr);
1575
1576 Label last_ch_block_label, ch_block_done_label;
1577 if (unroll_last_ch_block) {
1578 mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
1579 and_(reg_exec_flags, FLAG_OC_LAST);
1580 test(reg_exec_flags, reg_exec_flags);
1581 jnz(last_ch_block_label, T_NEAR);
1582 }
1583
1584 write_compute_bias(jcp.nb_ch_blocking, false);
1585
1586 if (unroll_last_ch_block) {
1587 jmp(ch_block_done_label, T_NEAR);
1588
1589 L(last_ch_block_label);
1590 write_compute_bias(last_ch_block, masked_ch_tail);
1591 L(ch_block_done_label);
1592 }
1593
1594 pop(reg_output_baddr);
1595
1596 } else {
1597 bool masked_ch_tail = jcp.ch_tail > 0;
1598 write_compute_bias(jcp.nb_ch_blocking, masked_ch_tail);
1599 }
1600}
1601
1602template <cpu_isa_t isa>
1603void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::deploy_ch_loop_bias() {
1604
1605 Label ch_loop_label, zero_bias_label, load_bias_done_label;
1606
1607 mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
1608 and_(reg_exec_flags, FLAG_ZERO_BIAS);
1609 test(reg_exec_flags, reg_exec_flags);
1610 jne(zero_bias_label, T_NEAR);
1611
1612 compute_ch_loop_bias(true); // load_bias
1613 jmp(load_bias_done_label, T_NEAR);
1614
1615 L(zero_bias_label);
1616 compute_ch_loop_bias(false); // zero_bias
1617
1618 L(load_bias_done_label);
1619}
1620
1621template <cpu_isa_t isa>
1622inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias() {
1623
1624 mov(reg_bias_baddr, ptr[this->param1 + GET_OFF(bias)]);
1625
1626 if (is_ddst_layout_nxc())
1627 deploy_ch_loop_bias();
1628 else
1629 compute_single_ch_block_bias();
1630}
1631
1632template <cpu_isa_t isa>
1633inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter_kh_loop(
1634 int nb_ch_blocking) {
1635
1636 const size_t filter_offset_kw = jcp.kw * jcp.ch_block * sizeof(float);
1637 const size_t filter_offset_kh = jcp.kh * filter_offset_kw;
1638
1639 Label kh_loop_label;
1640
1641 mov(reg_kh_aux, jcp.kh);
1642 L(kh_loop_label);
1643 {
1644 store_filter(nb_ch_blocking);
1645
1646 add(reg_tmp_filter, filter_offset_kw);
1647 dec(reg_kh_aux);
1648 cmp(reg_kh_aux, 0);
1649 jg(kh_loop_label, T_NEAR);
1650 }
1651
    /* Restore the filter pointer */
1653 sub(reg_tmp_filter, filter_offset_kh);
1654}
1655
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter_ch_loop() {

    bool write_ch_blocking_unroll
            = is_layout_nxc() && jcp.nb_ch > jcp.nb_ch_blocking;
    if (write_ch_blocking_unroll) {
        const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking;

        Label last_ch_block_label, ch_block_done_label;

        if (nb_ch_blocking_tail) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        zero_filter_kh_loop(jcp.nb_ch_blocking);

        if (nb_ch_blocking_tail) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            zero_filter_kh_loop(nb_ch_blocking_tail);
            L(ch_block_done_label);
        }
    } else {
        zero_filter_kh_loop(jcp.nb_ch_blocking);
    }
}

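// Zero-initialize the filter reduction buffer only when FLAG_ZERO_FILTER is
// set; otherwise leave the previously accumulated values untouched.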
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::deploy_zero_filter() {

    Label skip_zeroing_label;

    mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_FILTER);
    test(reg_exec_flags, reg_exec_flags);
    je(skip_zeroing_label, T_NEAR);

    zero_filter();

    mov(reg_tmp_filter, reg_filter_baddr);
    zero_filter_ch_loop();

    L(skip_zeroing_label);
}

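// Accumulate over the 'kh' kernel rows: for each row, load the filter
// accumulators, run the unrolled 'ow' step, and store the result back.
// A second loop then walks the input and filter pointers back to their
// starting positions.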
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_kh_step(
        int unroll_w, int l_pad, int pad_offset, int ow_block,
        int nb_ch_blocking, bool is_last_ch) {

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t input_offset = jcp.iw * ch_step * sizeof(float);
    const size_t filter_offset = jcp.kw * jcp.ch_block * sizeof(float);

    Label kh_loop_label, skip_loop_label;

    cmp(reg_kh, 0);
    je(skip_loop_label, T_NEAR);

    mov(reg_kh_aux, reg_kh);
    L(kh_loop_label);
    {
        load_filter(nb_ch_blocking, is_last_ch);
        dispatch_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block,
                nb_ch_blocking, is_last_ch);
        store_filter(nb_ch_blocking, is_last_ch);

        add(reg_tmp_filter, filter_offset);
        add(reg_tmp_input, input_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Restore the input and filter pointers to their pre-loop values */
    Label kh_comeback_label;
    mov(reg_kh_aux, reg_kh);
    L(kh_comeback_label);
    {
        sub(reg_tmp_input, input_offset);
        sub(reg_tmp_filter, filter_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_comeback_label, T_NEAR);
    }

    L(skip_loop_label);
}

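// Unroll the channel-block dimension. With the nxc layout and a channel-block
// tail (or a masked channel tail), emit a separate code path for the last
// block and select it at runtime via FLAG_OC_LAST.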
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ch_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    bool write_ch_blocking_unroll
            = is_layout_nxc() && jcp.nb_ch > jcp.nb_ch_blocking;
    if (write_ch_blocking_unroll) {

        const bool masked_ch_tail = jcp.ch_tail > 0;
        const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking;
        const int last_ch_block = nb_ch_blocking_tail > 0 ? nb_ch_blocking_tail
                                                          : jcp.nb_ch_blocking;
        const bool unroll_last_ch_block
                = nb_ch_blocking_tail > 0 || masked_ch_tail;

        Label last_ch_block_label, ch_block_done_label;
        if (unroll_last_ch_block) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block,
                jcp.nb_ch_blocking, false);

        if (unroll_last_ch_block) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            compute_kh_step(unroll_w, l_pad, pad_offset, ow_block,
                    last_ch_block, masked_ch_tail);
            L(ch_block_done_label);
        }
    } else {
        bool masked_ch_tail = jcp.ch_tail > 0 && is_layout_nxc();
        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block,
                jcp.nb_ch_blocking, masked_ch_tail);
    }
}

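// Loop over the output height. The 'kh' count and the input/filter pointers
// are adjusted at the top and bottom padding edges so that only kernel rows
// overlapping the input contribute to the reduction.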
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    mov(reg_tmp_output, reg_output_baddr);
    mov(reg_tmp_input, reg_input_baddr);
    mov(reg_tmp_filter, reg_filter_baddr);

    const int input_bottom_padding_overlap
            = div_up(jcp.ih + jcp.t_pad - (jcp.kh - 1), jcp.stride_h);

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t typesize = sizeof(float);
    const size_t input_shift = typesize * jcp.iw * ch_step;
    const size_t output_shift = typesize * jcp.ow * ch_step;
    const size_t filter_shift = typesize * jcp.kw * jcp.ch_block;

    Label loop_begin_label, loop_end_label, common_block_label,
            top_padding_end_label, bottom_padding_end_label,
            bottom_padding_label;

    mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_count)]);

    // replacement for 'os_index_end'
    mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);

    cmp(reg_kh, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kh
    cmp(reg_oh, reg_oh_worksize);
    jge(loop_end_label, T_NEAR); // no iterations along height dimension

    L(loop_begin_label);

    compute_ch_loop(unroll_w, l_pad, pad_offset, ow_block);

    /* Compute 'top' edge */
    if (jcp.t_pad > 0) {

        /* Check if within top padding region */
        cmp(reg_oh, div_up(jcp.t_pad, jcp.stride_h));
        jge(top_padding_end_label, T_NEAR);

        /* Increment step counter and adjust filter position */
        sub(reg_tmp_filter, filter_shift * jcp.stride_h);
        add(reg_kh, jcp.stride_h);

        /* Final number of kernel elements that overlap with input */
        const int inp_ker_overlap = nstl::min(jcp.kh, jcp.ih);
        cmp(reg_kh, inp_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and input */
        if (jcp.t_pad <= jcp.oh * jcp.stride_h) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.t_pad % jcp.stride_h != 0) {
                int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
                add(reg_tmp_filter, filter_shift * inp_corr);
                add(reg_tmp_input, input_shift * inp_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_tmp_filter,
                    (jcp.t_pad - jcp.oh * jcp.stride_h) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kh, inp_ker_overlap);
        jmp(common_block_label);

        L(top_padding_end_label);
    }

    /* Compute 'bottom' edge */
    if (jcp.b_pad > 0) {

        /* Check if within bottom padding region */
        cmp(reg_oh, input_bottom_padding_overlap - 1);
        jl(bottom_padding_end_label, T_NEAR);
        jg(bottom_padding_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * bottom padding region. */
        mov(reg_kh,
                jcp.ih + jcp.t_pad
                        - input_bottom_padding_overlap * jcp.stride_h);
        jmp(bottom_padding_end_label, T_NEAR);

        L(bottom_padding_label);
        sub(reg_kh, jcp.stride_h);
        cmp(reg_kh, 0);
        jle(loop_end_label, T_NEAR);

        L(bottom_padding_end_label);
    }

    /* Compute middle block */
    add(reg_tmp_input, input_shift * jcp.stride_h);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_tmp_output, output_shift);
    inc(reg_oh);
    cmp(reg_oh, reg_oh_worksize);
    jl(loop_begin_label, T_NEAR);

    L(loop_end_label);
}

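// Split the output width into an unrolled block size, a number of unroll
// trips, and a tail. If the tail is too small to cover the right padding,
// shift work from the unrolled trips into the tail.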
template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::calculate_w_unrolling(
        int &unroll_trips, int &unroll_w, int &unroll_w_tail) {

    const bool do_unroll_w = jcp.ow > max_unroll_w_;
    if (do_unroll_w) {
        unroll_w = nstl::min(block_size_, jcp.ow);
        unroll_trips = jcp.ow / unroll_w;
        /* calculate tail */
        unroll_w_tail = jcp.ow % unroll_w;
        /* Perform some rebalancing if the tail is too small */
        if ((unroll_w_tail == 0 && jcp.r_pad != 0)
                || (jcp.r_pad > 0 && jcp.r_pad >= unroll_w_tail)) {
            if (unroll_trips > 1) {
                unroll_w_tail += unroll_w;
                unroll_trips--;
            } else {
                /* Ideally, this case shouldn't happen */
                unroll_w_tail += (unroll_w - unroll_w / 2);
                unroll_w = unroll_w / 2;
            }
        }
    } else {
        unroll_w_tail = jcp.ow;
    }
}

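// Top-level driver for the weight-gradient kernel body: compute the bias
// reduction, zero the filter if requested, then process the left-padded
// block, the middle 'ow' blocks, and the right tail.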
template <cpu_isa_t isa>
inline void
jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {

    Label ow_blk_label; // for computing 'ow middle' block
    int pad_offset = 0;
    int l_pad = jcp.l_pad;

    int unroll_w_tail = 0;
    int unroll_w = 0;
    int unroll_trips = 0;
    calculate_w_unrolling(unroll_trips, unroll_w, unroll_w_tail);

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t data_offset = unroll_w * ch_step * sizeof(float);

    if (jcp.with_bias) compute_bias();

    /* Zero-initialize the filter if requested, then offset the filter base
     * address for h padding. */
    deploy_zero_filter();
    mov(reg_kh_offset, ptr[this->param1 + GET_OFF(filter_pad_off)]);
    add(reg_filter_baddr, reg_kh_offset);

    /* compute left padded block */
    const bool do_unroll_w = jcp.ow > max_unroll_w_;
    if (l_pad && do_unroll_w) {
        compute_h_loop(unroll_w, l_pad, 0, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
        unroll_trips--;
        pad_offset = l_pad;
        l_pad = 0;
    }

    /* Insert loop for 'ow' block when middle block needs to execute more
     * than once */
    bool do_ow_blk_loop = unroll_trips > 1;
    if (do_ow_blk_loop) {
        mov(reg_iter_ow_blk, unroll_trips);
        L(ow_blk_label);
    }
    if (unroll_trips > 0) {
        compute_h_loop(unroll_w, l_pad, pad_offset, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
    }
    if (do_ow_blk_loop) {
        dec(reg_iter_ow_blk);
        cmp(reg_iter_ow_blk, 0);
        jg(ow_blk_label, T_NEAR);
    }

    /* compute right padded block */
    if (unroll_w_tail) {
        compute_h_loop(
                unroll_w_tail, l_pad, pad_offset, jcp.ow - unroll_w_tail);
    }
}

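// Kernel entry point: emit the prologue, load the base addresses from the
// call arguments, prepare the channel-tail mask on avx512 when a tail exists,
// and generate the 'ow'-block unrolled body.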
template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
    assert(is_src_layout_nxc() == is_ddst_layout_nxc());

    preamble();

    mov(reg_input_baddr, ptr[this->param1 + GET_OFF(input)]);
    mov(reg_output_baddr, ptr[this->param1 + GET_OFF(output)]);
    mov(reg_filter_baddr, ptr[this->param1 + GET_OFF(filter)]);

    bool set_kmask = isa > avx2 && jcp.ch_tail > 0
            && (jcp.with_bias || is_layout_nxc());
    if (set_kmask) {
        // Prepare masks for tail
        Reg32 reg_tmp_32 = reg_tmp.cvt32();
        mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
        kmovw(k_ch_tail_mask, reg_tmp_32);
    }

    compute_ow_block_unroll();

    this->postamble();
}
#undef GET_OFF

template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>;
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl