1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/jit_uni_dw_conv_kernel_f32.hpp" |
24 | |
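// GET_OFF yields the byte offset of a field inside the runtime call-args
// struct (jit_conv_call_s), used to address ptr[param1 + GET_OFF(field)].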
25 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | using namespace dnnl::impl::prop_kind; |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::utils; |
35 | |
36 | using namespace Xbyak; |
37 | |
38 | template <cpu_isa_t isa> |
39 | jit_uni_dw_conv_fwd_kernel_f32<isa>::jit_uni_dw_conv_fwd_kernel_f32( |
40 | const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) |
41 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa), jcp(ajcp) { |
42 | if (jcp.with_eltwise || jcp.with_binary) { |
43 | using namespace binary_injector; |
44 | static constexpr bool preserve_gpr = true; |
45 | static constexpr bool preserve_vmm = false; |
46 | static constexpr size_t helper_vmm_idx = 31; |
47 | static constexpr bool use_exact_tail_scalar_bcast = true; |
48 | const size_t tail_size = jcp.oc_without_padding |
49 | % (cpu_isa_traits<isa>::vlen / sizeof(float)); |
50 | rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r14, r15, |
51 | r12, preserve_gpr, preserve_vmm, |
52 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
53 | memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask, |
54 | use_exact_tail_scalar_bcast}; |
55 | static_params_t static_params {this->param1, rhs_arg_static_params}; |
56 | |
57 | postops_injector_ |
58 | = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>( |
59 | this, jcp.post_ops, static_params); |
60 | } |
61 | } |
62 | |
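// Returns true when the i-th vector chunk of the last unrolled channel block
// extends past the channel tail and therefore needs a masked (partial) load.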
63 | bool check_if_tail_load(const bool is_ch_tail, const int c_tail, const int ch, |
64 | const int ur_ch_blocks, const int vlen, const int i) { |
65 | return is_ch_tail && (ch + 1 == ur_ch_blocks) && ((i + 1) * vlen > c_tail); |
66 | } |
67 | |
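// Initialize the accumulators: start from the bias when present (zero
// otherwise), then pre-add dst values for the sum post-op.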
68 | template <cpu_isa_t isa> |
69 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src( |
70 | int ur_ch_blocks, int ur_w, bool is_ch_tail) { |
71 | |
72 | const auto dst_layout_nxc = is_dst_layout_nxc(); |
73 | const auto ch_blk = jcp.ch_block; |
74 | const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk; |
75 | const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk; |
76 | const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float); |
77 | const int c_tail = jcp.oc % jcp.ch_block; |
78 | |
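    // sse41 covers a channel block in two half-width passes
    // (max_repeats() == 2); wider ISAs cover it in a single pass.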
79 | const int repeats = max_repeats(); |
80 | for (int i = 0; i < repeats; i++) { |
81 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
82 | const bool is_tail_load = check_if_tail_load( |
83 | is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i); |
84 | if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen) |
85 | continue; |
86 | for (int ow = 0; ow < ur_w; ow++) { |
87 | Vmm vmm_acc |
88 | = get_acc_reg(i * ur_ch_blocks * ur_w + ch * ur_w + ow); |
89 | |
90 | const int b_off = ch * ch_blk + i * vlen; |
91 | if (this->jcp.with_bias) { |
92 | if (is_tail_load) { |
93 | load_tail(vmm_acc, reg_bias, b_off * sizeof(float), |
94 | (c_tail - i * vlen) * sizeof(float)); |
95 | } else { |
96 | uni_vmovups(vmm_acc, |
97 | vmmword[reg_bias + b_off * sizeof(float)]); |
98 | } |
99 | } else { |
100 | uni_vpxor(vmm_acc, vmm_acc, vmm_acc); |
101 | } |
102 | |
103 | const int o_off = ch * ocb_stride + ow * ow_stride + i * vlen; |
104 | if (this->jcp.with_sum) { |
105 | if (is_tail_load) { |
106 | if (this->jcp.with_bias) { |
                        // Reuse the kernel vmm as vmm_tmp; it is safe to
                        // clobber here.
108 | auto vmm_tmp = get_ker_reg(0); |
109 | add_tail_from_mem(vmm_acc, vmm_tmp, reg_output, |
110 | o_off * sizeof(float), |
111 | (c_tail - i * vlen) * sizeof(float)); |
112 | } else { |
113 | // nothing to add, just load dst. |
                        load_tail(vmm_acc, reg_output,
                                o_off * sizeof(float),
                                (c_tail - i * vlen) * sizeof(float));
117 | } |
118 | } else { |
119 | // blocked layout has dst padded, so no tail handling. |
120 | uni_vaddps(vmm_acc, vmm_acc, |
121 | vmmword[reg_output + o_off * sizeof(float)]); |
122 | } |
123 | } |
124 | } |
125 | } |
126 | } |
127 | } |
128 | |
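// Core FMA loop: the kernel height is iterated at runtime (reg_kh) while the
// kernel width, output width and channel blocks are fully unrolled. With
// is_resrc_depthwise, overlapping input columns are preloaded once per row
// and reused across kernel points.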
129 | template <cpu_isa_t isa> |
130 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled( |
131 | int ur_ch_blocks, int ur_w, int pad_l, int pad_r, bool is_ch_tail) { |
132 | int ch_blk = jcp.ch_block; |
133 | int dilate_h = jcp.dilate_h + 1; |
134 | int dilate_w = jcp.dilate_w + 1; |
135 | int stride_w = jcp.stride_w; |
136 | |
137 | const auto src_layout_nxc = is_src_layout_nxc(); |
138 | const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk; |
139 | const auto ih_stride = jcp.iw * iw_stride; |
140 | const auto icb_stride = src_layout_nxc |
141 | ? ch_blk |
142 | : (jcp.is_fused_conv ? 1 : jcp.ih) * jcp.iw * ch_blk; |
143 | const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float); |
144 | |
145 | auto get_input_spatial_index = [=](int oi, int ki) { |
146 | return (ki * dilate_w + oi * stride_w - pad_l); |
147 | }; |
148 | |
149 | auto get_input_offset = [=](int ii, int ci, int rep) { |
150 | return (ci * icb_stride + ii * iw_stride + rep * vlen) |
151 | * jcp.typesize_in; |
152 | }; |
153 | |
154 | int ii_start = 0; |
155 | int ii_end = -1; |
156 | if (jcp.is_resrc_depthwise) { |
157 | // find bounds of input spatial indices |
158 | bool first = true; |
159 | for (int ki = 0; ki < jcp.kw; ki++) { |
160 | int oi_start = get_ow_start(ki, pad_l); |
161 | int oi_end = get_ow_end(ur_w, ki, pad_r); |
162 | for (int oi = oi_start; oi < oi_end; oi++) { |
163 | int ii = get_input_spatial_index(oi, ki); |
164 | if (first || ii < ii_start) ii_start = ii; |
165 | if (first || ii > ii_end) ii_end = ii; |
166 | first = false; |
167 | } |
168 | } |
169 | } |
170 | |
171 | Label iter_exit_label; |
172 | |
173 | cmp(reg_kh, 0); |
174 | je(iter_exit_label, T_NEAR); |
175 | |
176 | mov(iter_kh, reg_kh); |
177 | Label kh_label; |
178 | L(kh_label); |
179 | { |
180 | if (jcp.is_fused_conv) { |
181 | mov(aux_reg_input, ptr[aux_reg_input_buffer_ptr]); |
182 | add(aux_reg_input, reg_iw_offset); |
183 | } |
184 | const int c_tail = jcp.oc % jcp.ch_block; |
185 | const int repeats = max_repeats(); |
186 | for (int i = 0; i < repeats; i++) { |
187 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
188 | const bool is_tail_load = check_if_tail_load( |
189 | is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i); |
190 | if ((ch + 1 == ur_ch_blocks) && is_ch_tail |
191 | && c_tail <= i * vlen) |
192 | continue; |
193 | if (jcp.is_resrc_depthwise) { |
194 | // now we can load input once and reuse up to jcp.kw times |
195 | for (int ii = ii_start; ii <= ii_end; ii++) { |
196 | Vmm vmm_src = get_src_reg(ii); |
197 | const int inp_off = get_input_offset(ii, ch, i); |
198 | if (is_tail_load) { |
199 | load_tail(vmm_src, aux_reg_input, inp_off, |
200 | (c_tail - i * vlen) * jcp.typesize_in); |
201 | } else { |
202 | uni_vmovups(vmm_src, ptr[aux_reg_input + inp_off]); |
203 | } |
204 | } |
205 | } |
206 | for (int kw = 0; kw < jcp.kw; kw++) { |
207 | const int ker_off = ch * jcp.kh * jcp.kw * ch_blk |
208 | + kw * ch_blk + i * vlen; |
209 | |
210 | Vmm vmm_ker = get_ker_reg(0); |
211 | uni_vmovups(vmm_ker, |
212 | ptr[aux_reg_kernel + ker_off * sizeof(float)]); |
213 | |
214 | int ow_start = get_ow_start(kw, pad_l); |
215 | int ow_end = get_ow_end(ur_w, kw, pad_r); |
216 | for (int ow = ow_start; ow < ow_end; ow++) { |
217 | |
218 | const int ii = get_input_spatial_index(ow, kw); |
219 | Vmm vmm_src = jcp.is_resrc_depthwise ? get_src_reg(ii) |
220 | : get_src_reg(0); |
221 | if (!jcp.is_resrc_depthwise) { |
222 | const int inp_off = get_input_offset(ii, ch, i); |
223 | if (is_tail_load) { |
224 | load_tail(vmm_src, aux_reg_input, inp_off, |
225 | (c_tail - i * vlen) * jcp.typesize_in); |
226 | } else { |
227 | uni_vmovups( |
228 | vmm_src, ptr[aux_reg_input + inp_off]); |
229 | } |
230 | } |
231 | Vmm vmm_acc = get_acc_reg( |
232 | i * ur_ch_blocks * ur_w + ch * ur_w + ow); |
233 | uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); |
234 | } |
235 | } |
236 | } |
237 | } |
238 | |
239 | add(aux_reg_kernel, jcp.kw * ch_blk * sizeof(float)); |
240 | if (jcp.is_fused_conv) { |
241 | // Move to next row pointer in the buffer |
242 | add(aux_reg_input_buffer_ptr, sizeof(void *)); |
243 | } else { |
244 | add(aux_reg_input, ih_stride * dilate_h * sizeof(float)); |
245 | } |
246 | |
247 | dec(iter_kh); |
248 | cmp(iter_kh, 0); |
249 | jg(kh_label, T_NEAR); |
250 | } |
251 | |
252 | L(iter_exit_label); |
253 | } |
254 | |
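// Traversal helper for the post-op code: visits (repeat, channel block,
// output pixel) in accumulator-assignment order and flags the last channel
// block when a tail mask may be required.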
255 | template <typename F> |
256 | void iterate(const int repeats, const int ur_ch_blocks, const int ur_w, |
257 | const bool mask_tail, const F &f) { |
258 | for (int r = 0; r < repeats; r++) |
259 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
260 | const bool mask_flag = mask_tail && ch + 1 == ur_ch_blocks; |
261 | for (int ow = 0; ow < ur_w; ow++) |
262 | f(r, ch, ow, mask_flag); |
263 | } |
264 | } |
265 | |
266 | template <typename F> |
267 | void iterate( |
268 | const int repeats, const int ur_ch_blocks, const int ur_w, const F &f) { |
269 | iterate(repeats, ur_ch_blocks, ur_w, false, f); |
270 | } |
271 | |
272 | template <cpu_isa_t isa> |
273 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postops( |
274 | const int ur_ch_blocks, const int ur_w, const bool is_ch_tail) { |
275 | if (this->jcp.with_eltwise || this->jcp.with_binary) { |
276 | const int repeats = max_repeats(); |
277 | injector_utils::vmm_index_set_t vmm_idxs; |
278 | if (jcp.with_binary) { |
279 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params, |
280 | rhs_arg_params_tail; |
281 | const auto dst_layout_nxc = is_dst_layout_nxc(); |
282 | const auto ch_blk = jcp.ch_block; |
283 | const auto ocb_stride |
284 | = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk; |
285 | const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk; |
286 | const auto mask_tail_blocked_layout |
287 | = jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc; |
288 | const int c_tail = jcp.oc_without_padding % jcp.ch_block; |
289 | iterate(repeats, ur_ch_blocks, ur_w, mask_tail_blocked_layout, |
290 | [&](const int r, const int ch, const int ow, |
291 | const bool mask_flag_blocked_layout) { |
292 | const int vlen |
293 | = cpu_isa_traits<isa>::vlen / sizeof(float); |
294 | const bool is_tail_load = check_if_tail_load( |
295 | is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, r); |
296 | if ((ch + 1 == ur_ch_blocks) && is_ch_tail |
297 | && c_tail <= r * vlen) |
298 | return; |
299 | const size_t o_off = jcp.typesize_out |
300 | * (ch * ocb_stride + ow * ow_stride + r * vlen); |
301 | const auto vmm_idx = get_acc_reg_idx( |
302 | r * ur_ch_blocks * ur_w + ch * ur_w + ow); |
303 | vmm_idxs.emplace(vmm_idx); |
304 | |
305 | rhs_arg_params_tail.vmm_idx_to_out_reg.emplace( |
306 | vmm_idx, reg_output); |
307 | rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace( |
308 | vmm_idx, o_off); |
309 | if (mask_flag_blocked_layout || is_tail_load) |
310 | rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx); |
311 | }); |
312 | rhs_arg_params = rhs_arg_params_tail; |
313 | rhs_arg_params.vmm_tail_idx_.clear(); |
314 | |
315 | Label postops_done; |
316 | if (mask_tail_blocked_layout) { |
                // The mask_tail_blocked_layout approach of dynamic tail
                // handling is used for the blocked layout only. TODO: maybe
                // unify?
319 | Label postops_no_tail; |
320 | mov(reg_tmp, ptr[param1 + GET_OFF(load_work)]); |
321 | cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block); |
322 | jge(postops_no_tail, T_NEAR); |
323 | postops_injector_->compute_vector_range( |
324 | vmm_idxs, rhs_arg_params_tail); |
325 | jmp(postops_done, T_NEAR); |
326 | L(postops_no_tail); |
327 | } else if (is_ch_tail) { |
328 | postops_injector_->compute_vector_range( |
329 | vmm_idxs, rhs_arg_params_tail); |
330 | } |
331 | if (!is_ch_tail) { |
332 | postops_injector_->compute_vector_range( |
333 | vmm_idxs, rhs_arg_params); |
334 | L(postops_done); |
335 | } |
336 | } else { |
337 | iterate(repeats, ur_ch_blocks, ur_w, |
338 | [&](const int r, const int ch, const int ow, const bool) { |
339 | vmm_idxs.emplace(get_acc_reg_idx( |
340 | r * ur_ch_blocks * ur_w + ch * ur_w + ow)); |
341 | }); |
342 | postops_injector_->compute_vector_range(vmm_idxs); |
343 | } |
344 | } |
345 | } |
346 | |
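// Tail load/store helpers: avx512_core uses the k_oc_tail_mask opmask for
// partial vectors, while avx2 and sse41 fall back to byte-wise
// load_bytes/store_bytes emulation.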
347 | template <cpu_isa_t isa> |
348 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_tail( |
349 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
350 | uni_vmovups(vmm | k_oc_tail_mask | T_z, ptr[reg + offset]); |
351 | } |
352 | |
353 | template <> |
354 | void jit_uni_dw_conv_fwd_kernel_f32<avx2>::load_tail( |
355 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
356 | load_bytes(vmm, reg, offset, load_size); |
357 | } |
358 | |
359 | template <> |
360 | void jit_uni_dw_conv_fwd_kernel_f32<sse41>::load_tail( |
361 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
362 | load_bytes(vmm, reg, offset, load_size); |
363 | } |
364 | |
365 | template <cpu_isa_t isa> |
366 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::add_tail_from_mem(Vmm &vmm_acc, |
367 | Vmm &vmm_tmp, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
368 | uni_vaddps(vmm_acc | k_oc_tail_mask | T_z, vmm_acc, ptr[reg + offset]); |
369 | } |
370 | |
371 | template <> |
372 | void jit_uni_dw_conv_fwd_kernel_f32<avx2>::add_tail_from_mem(Vmm &vmm_acc, |
373 | Vmm &vmm_tmp, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
374 | load_bytes(vmm_tmp, reg, offset, load_size); |
375 | uni_vaddps(vmm_acc, vmm_acc, vmm_tmp); |
376 | } |
377 | |
378 | template <> |
379 | void jit_uni_dw_conv_fwd_kernel_f32<sse41>::add_tail_from_mem(Vmm &vmm_acc, |
380 | Vmm &vmm_tmp, const Xbyak::Reg64 ®, int64_t offset, int load_size) { |
381 | load_bytes(vmm_tmp, reg, offset, load_size); |
382 | uni_vaddps(vmm_acc, vmm_acc, vmm_tmp); |
383 | } |
384 | |
385 | template <cpu_isa_t isa> |
386 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_tail( |
387 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int store_size) { |
388 | uni_vmovups(vmmword[reg + offset], vmm | k_oc_tail_mask); |
389 | } |
390 | |
391 | template <> |
392 | void jit_uni_dw_conv_fwd_kernel_f32<avx2>::store_tail( |
393 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int store_size) { |
394 | store_bytes(vmm, reg, offset, store_size); |
395 | } |
396 | |
397 | template <> |
398 | void jit_uni_dw_conv_fwd_kernel_f32<sse41>::store_tail( |
399 | Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, int store_size) { |
400 | store_bytes(vmm, reg, offset, store_size); |
401 | } |
402 | |
403 | template <cpu_isa_t isa> |
404 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst( |
405 | int ur_ch_blocks, int ur_w, bool is_ch_tail) { |
406 | |
407 | const auto dst_layout_nxc = is_dst_layout_nxc(); |
408 | const auto ch_blk = jcp.ch_block; |
409 | const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk; |
410 | const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk; |
411 | const int vlen = cpu_isa_traits<isa>::vlen / sizeof(float); |
412 | const int c_tail = jcp.oc_without_padding % jcp.ch_block; |
413 | |
414 | const int repeats = max_repeats(); |
415 | for (int i = 0; i < repeats; i++) { |
416 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
417 | const bool is_tail_load = check_if_tail_load( |
418 | is_ch_tail, c_tail, ch, ur_ch_blocks, vlen, i); |
419 | if ((ch + 1 == ur_ch_blocks) && is_ch_tail && c_tail <= i * vlen) |
420 | continue; |
421 | for (int ow = 0; ow < ur_w; ow++) { |
422 | const int o_off = ch * ocb_stride + ow * ow_stride + i * vlen; |
423 | Vmm vmm_dst |
424 | = get_acc_reg(i * ur_ch_blocks * ur_w + ch * ur_w + ow); |
425 | if (is_tail_load) { |
426 | store_tail(vmm_dst, reg_output, o_off * sizeof(float), |
427 | (c_tail - i * vlen) * sizeof(float)); |
428 | } else |
429 | uni_vmovups(vmmword[reg_output + o_off * sizeof(float)], |
430 | vmm_dst); |
431 | } |
432 | } |
433 | } |
434 | } |
435 | |
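// Compute one output-width tile. When ur_ch_blocks exceeds jcp.nb_ch_blocking
// (nxc layout), a runtime channel loop advances the data pointers, which are
// saved and restored around it, and handles the channel tail.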
436 | template <cpu_isa_t isa> |
437 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::compute_loop( |
438 | int ur_w, int ur_ch_blocks, int pad_l, int pad_r) { |
439 | |
440 | const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking; |
    // ch_loop currently happens only when the data layout is nxc. The strides
    // are calculated for this layout only.
443 | const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw |
444 | * jcp.ch_block * jcp.typesize_in; |
445 | const size_t inp_ch_stride |
446 | = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_in; |
447 | const size_t out_ch_stride |
448 | = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_out; |
449 | const size_t bias_stride |
450 | = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); |
451 | |
452 | auto compute = [&](int ur_ch_blocks, bool is_ch_tail) { |
453 | if (jcp.is_fused_conv) { |
454 | mov(aux_reg_input_buffer_ptr, reg_input_buffer_ptr); |
455 | } else { |
456 | mov(aux_reg_input, reg_input); |
457 | } |
458 | |
459 | mov(aux_reg_kernel, reg_kernel); |
460 | load_src(ur_ch_blocks, ur_w, is_ch_tail); |
461 | apply_filter_unrolled(ur_ch_blocks, ur_w, pad_l, pad_r, is_ch_tail); |
462 | apply_postops(ur_ch_blocks, ur_w, is_ch_tail); |
463 | store_dst(ur_ch_blocks, ur_w, is_ch_tail); |
464 | }; |
465 | |
466 | mov(aux_reg_ch_blocks, reg_ch_blocks); |
467 | if (ch_loop) { |
468 | Label ch_loop_label, ch_tail_label, skip_ch_tail_label; |
469 | const int ch_block_tail = jcp.nb_ch |
470 | - (utils::rnd_dn(jcp.oc / jcp.ch_block, jcp.nb_ch_blocking)); |
471 | const int ch_step = jcp.nb_ch_blocking * jcp.ch_block; |
472 | |
473 | push(reg_kernel); |
474 | push(reg_input); |
475 | push(reg_output); |
476 | if (jcp.with_bias) push(reg_bias); |
477 | |
478 | if ((jcp.oc / jcp.ch_block) >= jcp.nb_ch_blocking) { |
479 | if (ch_block_tail) { |
480 | cmp(aux_reg_ch_blocks, ch_step); |
481 | jl(ch_tail_label, T_NEAR); |
482 | } |
483 | |
484 | L(ch_loop_label); |
485 | { |
486 | compute(jcp.nb_ch_blocking, false); |
487 | add(reg_kernel, wei_ch_stride); |
488 | add(reg_input, inp_ch_stride); |
489 | add(reg_output, out_ch_stride); |
490 | if (jcp.with_bias) add(reg_bias, bias_stride); |
491 | sub(aux_reg_ch_blocks, ch_step); |
492 | cmp(aux_reg_ch_blocks, ch_step); |
493 | jge(ch_loop_label, T_NEAR); |
494 | } |
495 | } |
496 | |
497 | if (ch_block_tail) { |
498 | // ch work range [1, jcp.nb_ch_blocking * ch_block) |
499 | L(ch_tail_label); |
500 | cmp(aux_reg_ch_blocks, 0); |
501 | jle(skip_ch_tail_label, T_NEAR); |
        compute(ch_block_tail, jcp.oc % jcp.ch_block > 0);
503 | L(skip_ch_tail_label); |
504 | } |
505 | |
506 | if (jcp.with_bias) pop(reg_bias); |
507 | pop(reg_output); |
508 | pop(reg_input); |
509 | pop(reg_kernel); |
510 | |
511 | } else { |
        compute(ur_ch_blocks, jcp.oc % jcp.ch_block > 0);
513 | } |
514 | } |
515 | |
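// Tile the output width: an optional left-padded head tile, a loop of n_oi
// unpadded ur_w tiles, an optional right-padded tile (r_pad1 > 0), and a
// final remainder of ur_w_tail pixels.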
516 | template <cpu_isa_t isa> |
517 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::ow_loop(int ur_ch_blocks) { |
518 | |
519 | int iw = jcp.iw; |
520 | int ow = jcp.ow; |
521 | int kw = jcp.kw; |
522 | int l_pad = jcp.l_pad; |
523 | int ur_w = jcp.ur_w; |
524 | int ur_w_tail = jcp.ur_w_tail; |
525 | int stride_w = jcp.stride_w; |
526 | |
527 | const auto src_layout_nxc = is_src_layout_nxc(); |
528 | const auto dat_c_stride = src_layout_nxc ? jcp.ngroups : jcp.ch_block; |
529 | size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * dat_c_stride; |
530 | size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride; |
531 | |
532 | int inp_shift_pad |
533 | = jcp.typesize_in * (ur_w * stride_w - l_pad) * dat_c_stride; |
534 | |
535 | int r_pad = nstl::max(0, jcp.r_pad); |
536 | int n_oi = ow / ur_w; |
537 | int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w, |
538 | calculate_extended_filter_size(kw, jcp.dilate_w)); |
539 | |
540 | assert(jcp.nb_ow <= 1); |
541 | |
542 | if (r_pad1 > 0) n_oi--; |
543 | xor_(reg_oi, reg_oi); |
544 | if (ow == ur_w) { |
545 | compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad); |
546 | } else { |
547 | if (n_oi == 0) { |
548 | compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad1); |
549 | add(reg_input, inp_shift_pad); |
550 | add(reg_output, out_shift); |
551 | if (ur_w_tail != 0) { |
552 | compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad); |
553 | } |
554 | } else { |
555 | if (l_pad > 0) { |
556 | compute_loop(ur_w, ur_ch_blocks, l_pad, 0); |
557 | add(reg_input, inp_shift_pad); |
558 | add(reg_output, out_shift); |
559 | inc(reg_oi); |
560 | } |
561 | if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { |
562 | Label ow_loop_label; |
563 | L(ow_loop_label); |
564 | { |
565 | compute_loop(ur_w, ur_ch_blocks, 0, 0); |
566 | add(reg_input, inp_shift); |
567 | add(reg_output, out_shift); |
568 | |
569 | inc(reg_oi); |
570 | cmp(reg_oi, n_oi); |
571 | jl(ow_loop_label, T_NEAR); |
572 | } |
573 | } |
574 | if (r_pad1 > 0) { |
575 | compute_loop(ur_w, ur_ch_blocks, 0, r_pad1); |
576 | add(reg_input, inp_shift); |
577 | add(reg_output, out_shift); |
578 | } |
579 | if (ur_w_tail != 0) { |
580 | compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad); |
581 | } |
582 | } |
583 | } |
584 | } |
585 | |
586 | template <cpu_isa_t isa> |
587 | void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() { |
588 | this->preamble(); |
589 | |
590 | if (jcp.is_fused_conv) { |
591 | mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]); |
        /* In case of a fused depthwise convolution, `param.src` is not a
        pointer to the input; instead it points to a buffer containing
        pointers to consecutive rows of the input in format Cwc with blocking
        nb_ch_blocking.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
596 | Traverse the data as |
597 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
598 | ... process row0 ... |
599 | add(reg_input_buffer_ptr, sizeof(void*)) |
600 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
601 | ... process row1 ... |
602 | add(reg_input_buffer_ptr, sizeof(void*)) |
603 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
604 | ... process row2 ... |
605 | */ |
606 | xor_(reg_iw_offset, reg_iw_offset); |
607 | } else { |
608 | mov(reg_input, ptr[this->param1 + GET_OFF(src)]); |
609 | } |
610 | mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); |
611 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
612 | if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); |
613 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); |
614 | mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]); |
615 | |
616 | Label ch_blocks_tail_label; |
617 | Label exit_label; |
618 | |
619 | int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; |
620 | if (isa > avx2) { |
621 | const auto oc_tail = jcp.oc_without_padding % jcp.ch_block; |
622 | if (oc_tail != 0) { |
            // Prepare the opmask for tail processing
            const int oc_tail_shift = jcp.ch_block - oc_tail;
626 | static constexpr auto zmm_full_mask = ((1 << 16) - 1); |
627 | Reg32 reg_tail_32 = reg_tail.cvt32(); |
628 | mov(reg_tail_32, (zmm_full_mask >> oc_tail_shift)); |
629 | kmovw(k_oc_tail_mask, reg_tail_32); |
630 | } |
631 | } |
632 | |
633 | if (is_src_layout_nxc()) { |
634 | ow_loop(jcp.nb_ch); |
635 | } else { |
636 | cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block); |
637 | jle(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); |
638 | |
639 | ow_loop(jcp.nb_ch_blocking); // channel main loop |
640 | |
641 | if (ch_blocks_tail) { |
642 | jmp(exit_label, T_NEAR); |
643 | L(ch_blocks_tail_label); |
644 | ow_loop(ch_blocks_tail); // channel tail loop |
645 | } |
646 | |
647 | L(exit_label); |
648 | } |
649 | |
650 | this->postamble(); |
651 | |
652 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
653 | } |
654 | |
655 | template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_core>; |
656 | template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>; |
657 | template struct jit_uni_dw_conv_fwd_kernel_f32<sse41>; |
658 | |
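// Backward-data tail helpers mirror the forward ones: avx512_core masks with
// k_ch_tail_mask, while avx2 and sse41 emulate partial vectors with
// load_bytes/store_bytes.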
659 | template <cpu_isa_t isa> |
660 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_vmm( |
661 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
662 | int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41 |
663 | int bytes = (tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float); |
664 | load_bytes(vmm, addr, bytes); |
665 | } |
666 | template <> |
667 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx2>::load_vmm( |
668 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
669 | int bytes = (tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float); |
670 | load_bytes(vmm, addr, bytes); |
671 | } |
672 | template <> |
673 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>::load_vmm( |
674 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
675 | Zmm masked_vmm = tail ? vmm | k_ch_tail_mask | T_z : vmm; |
676 | vmovups(masked_vmm, addr); |
677 | } |
678 | |
679 | template <cpu_isa_t isa> |
680 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_vmm( |
681 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
682 | int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41 |
683 | int bytes = (tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float); |
684 | store_bytes(vmm, addr, bytes); |
685 | } |
686 | template <> |
687 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx2>::store_vmm( |
688 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
689 | int bytes = (tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float); |
690 | store_bytes(vmm, addr, bytes); |
691 | } |
692 | template <> |
693 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>::store_vmm( |
694 | Vmm &vmm, const Xbyak::Address &addr, bool tail) { |
695 | Zmm masked_vmm = tail ? vmm | k_ch_tail_mask : vmm; |
696 | vmovups(addr, masked_vmm); |
697 | } |
698 | |
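// Despite its name, load_ddst only zero-initializes the accumulators; the
// actual diff_dst loads happen inside apply_filter.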
699 | template <cpu_isa_t isa> |
700 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst( |
701 | int ur_ch_blocks, int ur_str_w) { |
702 | for (int i = 0; i < reg_repeats_; i++) { |
703 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
704 | for (int w = 0; w < ur_str_w; w++) { |
705 | Vmm vmm_acc = get_acc_reg( |
706 | i * ur_ch_blocks * ur_str_w + ch * ur_str_w + w); |
707 | uni_vpxor(vmm_acc, vmm_acc, vmm_acc); |
708 | } |
709 | } |
710 | } |
711 | } |
712 | |
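// Accumulate diff_src: kh/kw are traversed at runtime in steps of
// stride_h/stride_w, since only those kernel points contribute to a given
// input pixel.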
713 | template <cpu_isa_t isa> |
714 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter( |
715 | int ur_ch_blocks, int ur_str_w, bool is_last_ch) { |
716 | int kw = jcp.kw; |
717 | int kh = jcp.kh; |
718 | int ow = jcp.ow; |
719 | int oh = jcp.oh; |
720 | |
721 | int ch_blk = jcp.ch_block; |
722 | int stride_h = jcp.stride_h; |
723 | int stride_w = jcp.stride_w; |
724 | |
725 | const bool ddst_layout_nxc = is_ddst_layout_nxc(); |
726 | const size_t ch_block_step = ch_blk * (ddst_layout_nxc ? 1 : oh * ow); |
727 | const size_t sp_step = ddst_layout_nxc ? jcp.ngroups : ch_blk; |
728 | |
729 | Label iter_exit_label; |
730 | |
731 | cmp(reg_kh, 0); |
732 | je(iter_exit_label, T_NEAR); |
733 | |
734 | cmp(reg_kw, 0); |
735 | je(iter_exit_label, T_NEAR); |
736 | |
737 | mov(iter_kh, reg_kh); |
738 | Label kh_label; |
739 | L(kh_label); |
740 | { |
741 | mov(aux1_reg_ddst, aux_reg_ddst); |
742 | mov(aux1_reg_kernel, aux_reg_kernel); |
743 | |
744 | mov(iter_kw, reg_kw); |
745 | Label kw_label; |
746 | L(kw_label); |
747 | { |
748 | for (int r = 0; r < reg_repeats_; r++) { |
749 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
750 | bool last_block = is_last_ch && ch == ur_ch_blocks - 1; |
751 | bool masked_load = last_block |
752 | && IMPLICATION( |
753 | isa == sse41, tail_simd_overlap(r + 1)); |
754 | |
                    // sse41: if the second simd_w half is outside the
                    // channel block, skip it
756 | if (last_block && isa == sse41 && tail_simd_overlap(r)) |
757 | break; |
758 | |
759 | int ker_off = ch * kh * kw * ch_blk + r * simd_w_; |
760 | Vmm vmm_ker = get_ker_reg(0); |
761 | load_vmm(vmm_ker, |
762 | ptr[aux1_reg_kernel + ker_off * sizeof(float)], |
763 | masked_load); |
764 | |
765 | for (int w = 0; w < ur_str_w; w++) { |
766 | size_t sp_offset = w * sp_step; |
767 | size_t ch_offset = ch * ch_block_step; |
768 | size_t ddst_off = static_cast<size_t>( |
769 | (sp_offset + ch_offset + r * simd_w_) |
770 | * sizeof(float)); |
771 | |
772 | Vmm vmm_ddst = get_ddst_reg(0); |
773 | load_vmm(vmm_ddst, ptr[aux1_reg_ddst + ddst_off], |
774 | masked_load); |
775 | |
776 | Vmm vmm_acc = get_acc_reg(r * ur_ch_blocks * ur_str_w |
777 | + ch * ur_str_w + w); |
778 | uni_vfmadd231ps(vmm_acc, vmm_ddst, vmm_ker); |
779 | } |
780 | } |
781 | } |
782 | |
783 | add(aux1_reg_kernel, ch_blk * stride_w * sizeof(float)); |
784 | sub(aux1_reg_ddst, sp_step * sizeof(float)); |
785 | |
786 | sub(iter_kw, stride_w); |
787 | cmp(iter_kw, 0); |
788 | jg(kw_label, T_NEAR); |
789 | } |
790 | |
791 | add(aux_reg_kernel, kw * ch_blk * stride_h * sizeof(float)); |
792 | sub(aux_reg_ddst, ow * sp_step * sizeof(float)); |
793 | |
794 | sub(iter_kh, stride_h); |
795 | cmp(iter_kh, 0); |
796 | jg(kh_label, T_NEAR); |
797 | } |
798 | |
799 | L(iter_exit_label); |
800 | } |
801 | |
802 | template <cpu_isa_t isa> |
803 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc( |
804 | int ur_ch_blocks, int ur_str_w, bool is_last_ch) { |
805 | int ch_block = jcp.ch_block; |
806 | int iw = jcp.iw; |
807 | int ih = jcp.ih; |
808 | int stride_w = jcp.stride_w; |
809 | |
810 | const auto dsrc_layout_nxc = is_dsrc_layout_nxc(); |
811 | const size_t ch_block_step = ch_block * (dsrc_layout_nxc ? 1 : ih * iw); |
812 | const size_t sp_step |
813 | = dsrc_layout_nxc ? jcp.ngroups : ch_block; // spatial step |
814 | |
815 | for (int r = 0; r < reg_repeats_; r++) { |
816 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
817 | bool last_block = is_last_ch && ch == ur_ch_blocks - 1; |
818 | bool masked_store = last_block |
819 | && IMPLICATION(isa == sse41, tail_simd_overlap(r + 1)); |
820 | |
            // sse41: if the second simd_w half is outside the channel
            // block, skip it
822 | if (last_block && tail_simd_overlap(r)) break; |
823 | |
824 | for (int w = 0; w < ur_str_w; w++) { |
825 | size_t sp_offset = w * stride_w * sp_step; |
826 | size_t ch_offset = ch * ch_block_step + r * simd_w_; |
827 | size_t dsrc_off = static_cast<size_t>( |
828 | (sp_offset + ch_offset) * sizeof(float)); |
829 | |
830 | Vmm vmm_acc |
831 | = get_acc_reg((r * ur_ch_blocks + ch) * ur_str_w + w); |
832 | store_vmm(vmm_acc, ptr[reg_dsrc + dsrc_off], masked_store); |
833 | } |
834 | } |
835 | } |
836 | } |
837 | |
838 | template <cpu_isa_t isa> |
839 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::ch_loop_body( |
840 | int ur_ch_blocks, int unroll_w) { |
841 | |
842 | auto call_compute_body |
843 | = [&](int ur_ch_blocks, int unroll_w, bool is_last_ch = false) { |
844 | mov(aux_reg_ddst, reg_ddst); |
845 | mov(aux_reg_kernel, reg_kernel); |
846 | |
847 | load_ddst(ur_ch_blocks, unroll_w); |
848 | apply_filter(ur_ch_blocks, unroll_w, is_last_ch); |
849 | store_dsrc(ur_ch_blocks, unroll_w, is_last_ch); |
850 | }; |
851 | |
852 | const bool write_ch_loop = ur_ch_blocks > jcp.nb_ch_blocking; |
853 | if (write_ch_loop) { |
854 | assert(is_ddst_layout_nxc()); |
855 | |
856 | Label ch_loop_label, ch_tail_label, skip_ch_tail_label; |
857 | const int nb_oc = jcp.oc / jcp.ch_block; |
858 | const int ch_block_tail |
859 | = jcp.nb_ch - (utils::rnd_dn(nb_oc, jcp.nb_ch_blocking)); |
860 | const int ch_step = jcp.nb_ch_blocking * jcp.ch_block; |
861 | |
862 | const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh |
863 | * jcp.kw * jcp.ch_block * sizeof(float); |
864 | const size_t data_ch_stride |
865 | = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float); |
866 | |
867 | mov(aux_reg_ch_blocks, reg_ch_blocks); |
868 | push(reg_dsrc); |
869 | push(reg_ddst); |
870 | push(reg_kernel); |
871 | |
872 | if (nb_oc >= jcp.nb_ch_blocking) { |
873 | if (ch_block_tail) { |
874 | cmp(aux_reg_ch_blocks, jcp.nb_ch_blocking * jcp.ch_block); |
875 | jl(ch_tail_label, T_NEAR); |
876 | } |
877 | |
878 | L(ch_loop_label); |
879 | { |
880 | call_compute_body(jcp.nb_ch_blocking, unroll_w); |
881 | |
882 | add(reg_kernel, wei_ch_stride); |
883 | add(reg_dsrc, data_ch_stride); |
884 | add(reg_ddst, data_ch_stride); |
885 | |
886 | sub(aux_reg_ch_blocks, ch_step); |
887 | cmp(aux_reg_ch_blocks, ch_step); |
888 | jge(ch_loop_label, T_NEAR); |
889 | } |
890 | } |
891 | |
892 | if (ch_block_tail) { |
893 | // ch work range [1, jcp.nb_ch_blocking * ch_block) |
894 | L(ch_tail_label); |
895 | cmp(aux_reg_ch_blocks, 0); |
896 | jle(skip_ch_tail_label, T_NEAR); |
897 | call_compute_body(ch_block_tail, unroll_w, jcp.ch_tail > 0); |
898 | L(skip_ch_tail_label); |
899 | } |
900 | |
901 | pop(reg_kernel); |
902 | pop(reg_ddst); |
903 | pop(reg_dsrc); |
904 | |
905 | } else { |
906 | call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail > 0); |
907 | } |
908 | } |
909 | |
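// Process the width in chunks of jcp.ur_w, then fall back to single-pixel
// steps for the remainder.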
910 | template <cpu_isa_t isa> |
911 | inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::unroll_width_body( |
912 | int ur_ch_blocks) { |
913 | assert(is_dsrc_layout_nxc() == is_ddst_layout_nxc()); |
914 | const size_t ch_step = sizeof(float) |
915 | * (is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block); |
916 | |
917 | auto unroll_width_loop = [&](int unroll_w) { |
918 | Label unroll_w_label, skip_compute_label; |
919 | L(unroll_w_label); |
920 | { |
921 | cmp(reg_ur_str_w, unroll_w); |
922 | jl(skip_compute_label, T_NEAR); |
923 | |
924 | ch_loop_body(ur_ch_blocks, unroll_w); |
925 | |
926 | add(reg_dsrc, unroll_w * jcp.stride_w * ch_step); |
927 | add(reg_ddst, unroll_w * ch_step); |
928 | |
929 | sub(reg_ur_str_w, unroll_w); |
930 | jmp(unroll_w_label); |
931 | } |
932 | L(skip_compute_label); |
933 | }; |
934 | |
935 | unroll_width_loop(jcp.ur_w); |
936 | |
937 | unroll_width_loop(1); |
938 | } |
939 | |
940 | template <cpu_isa_t isa> |
941 | void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() { |
942 | preamble(); |
943 | |
944 | mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); |
945 | mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); |
946 | mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); |
947 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); |
948 | mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); |
949 | mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); |
950 | mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); |
951 | |
952 | if (is_dsrc_layout_nxc()) { |
953 | if (isa == avx512_core && (jcp.ch_tail > 0)) { |
954 | Label masking_done; |
955 | const size_t channel_step = jcp.nb_ch_blocking * jcp.ch_block; |
            kxnorw(k_ch_tail_mask, k_ch_tail_mask,
                    k_ch_tail_mask); // dummy mask of all 1's
958 | cmp(reg_ch_blocks, channel_step); |
959 | je(masking_done, T_NEAR); |
960 | // Prepare masks for tail |
961 | Reg32 reg_tmp_32 = reg_tmp.cvt32(); |
962 | mov(reg_tmp_32, (1 << jcp.ch_tail) - 1); |
963 | kmovw(k_ch_tail_mask, reg_tmp_32); |
964 | L(masking_done); |
965 | } |
966 | |
967 | unroll_width_body(jcp.nb_ch); |
968 | } else { |
969 | |
970 | auto ch_blocks_loop = [&](int ch_blocks) { |
971 | Label skip_loop_label; |
972 | cmp(reg_ch_blocks, ch_blocks * jcp.ch_block); |
973 | jl(skip_loop_label, T_NEAR); |
974 | unroll_width_body(ch_blocks); |
975 | L(skip_loop_label); |
976 | }; |
977 | |
978 | ch_blocks_loop(jcp.nb_ch_blocking); |
979 | |
980 | int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; |
981 | if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); } |
982 | } |
983 | |
984 | this->postamble(); |
985 | } |
986 | #undef GET_OFF |
987 | |
988 | template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_core>; |
989 | template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>; |
990 | template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse41>; |
991 | |
992 | #define GET_OFF(field) offsetof(jit_dw_conv_call_s, field) |
993 | |
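// Backward-weights tail helpers: the same per-ISA masking scheme as above,
// plus an avx2 addps variant that keeps a memory operand on the non-tail
// path.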
994 | template <cpu_isa_t isa> |
995 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_xmm( |
996 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
997 | int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41 |
998 | int bytes |
999 | = (compute_tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float); |
1000 | load_bytes(vmm, addr, bytes); |
1001 | } |
1002 | template <> |
1003 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::load_xmm( |
1004 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
1005 | int bytes = (compute_tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float); |
1006 | load_bytes(vmm, addr, bytes); |
1007 | } |
1008 | template <> |
1009 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::load_xmm( |
1010 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
1011 | Zmm masked_vmm = compute_tail ? vmm | k_ch_tail_mask | T_z : vmm; |
1012 | vmovups(masked_vmm, addr); |
1013 | } |
1014 | |
1015 | template <cpu_isa_t isa> |
1016 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_xmm( |
1017 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
1018 | int ch_tail = jcp.oc_without_padding % simd_w_; // special case for SSE41 |
1019 | int bytes |
1020 | = (compute_tail && ch_tail > 0 ? ch_tail : simd_w_) * sizeof(float); |
1021 | store_bytes(vmm, addr, bytes); |
1022 | } |
1023 | template <> |
1024 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::store_xmm( |
1025 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
1026 | int bytes = (compute_tail ? jcp.ch_tail : jcp.ch_block) * sizeof(float); |
1027 | store_bytes(vmm, addr, bytes); |
1028 | } |
1029 | template <> |
1030 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::store_xmm( |
1031 | Vmm &vmm, const Xbyak::Address &addr, bool compute_tail) { |
1032 | Zmm masked_vmm = compute_tail ? vmm | k_ch_tail_mask : vmm; |
1033 | vmovups(addr, masked_vmm); |
1034 | } |
1035 | |
1036 | template <cpu_isa_t isa> |
1037 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::addps_xmm(Vmm &vmm_dst, |
1038 | Vmm &vmm_src, const Xbyak::Address &addr, bool compute_tail) { |
1039 | load_xmm(vmm_src, addr, compute_tail); |
1040 | uni_vaddps(vmm_dst, vmm_dst, vmm_src); |
1041 | } |
1042 | template <> |
1043 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>::addps_xmm( |
1044 | Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr, |
1045 | bool compute_tail) { |
1046 | if (compute_tail) { |
1047 | load_xmm(vmm_src, addr, true); |
1048 | uni_vaddps(vmm_dst, vmm_dst, vmm_src); |
1049 | } else { |
1050 | assert(vmm_dst.getIdx() == vmm_src.getIdx()); |
1051 | uni_vaddps(vmm_dst, vmm_src, addr); |
1052 | } |
1053 | } |
1054 | template <> |
1055 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>::addps_xmm( |
1056 | Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr, |
1057 | bool compute_tail) { |
1058 | Zmm masked_vmm = compute_tail ? vmm_src | k_ch_tail_mask | T_z : vmm_src; |
1059 | vaddps(vmm_dst, masked_vmm, addr); |
1060 | } |
1061 | |
1062 | template <cpu_isa_t isa> |
1063 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() { |
1064 | for (int ch = 0; ch < jcp.nb_ch_blocking; ++ch) { |
1065 | for (int r = 0; r < reg_repeats_; ++r) { |
1066 | for (int i = 0; i < jcp.kw; ++i) { |
1067 | Vmm vmm_acc |
1068 | = get_acc_reg(r * jcp.kw + i * jcp.nb_ch_blocking + ch); |
1069 | uni_vpxor(vmm_acc, vmm_acc, vmm_acc); |
1070 | } |
1071 | } |
1072 | } |
1073 | } |
1074 | |
1075 | template <> |
1076 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::load_filter( |
1077 | int nb_ch_blocking, bool is_last_ch) { |
1078 | assert(nb_ch_blocking == 1); |
1079 | for (int r = 0; r < reg_repeats_; ++r) { |
1080 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1081 | bool masked_load = tail_in_first_simd && is_last_ch; |
1082 | const int reg_set = r * jcp.kw; |
1083 | for (int i = 0; i < jcp.kw; ++i) { |
1084 | size_t off_filter = static_cast<size_t>( |
1085 | (i * jcp.ch_block + r * simd_w_) * sizeof(float)); |
1086 | Vmm vmm_acc = get_acc_reg(reg_set + i); |
1087 | load_xmm( |
1088 | vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load); |
1089 | } |
        if (masked_load) break; // tail ends in this half; skip remaining repeats
1091 | } |
1092 | } |
1093 | |
1094 | template <cpu_isa_t isa> |
1095 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter( |
1096 | int nb_ch_blocking, bool is_last_ch) { |
1097 | const size_t filter_step = jcp.kh * jcp.kw; |
1098 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1099 | bool masked_load = is_last_ch && (ch == nb_ch_blocking - 1); |
1100 | for (int i = 0; i < jcp.kw; ++i) { |
1101 | size_t off_filter = static_cast<size_t>( |
1102 | (ch * filter_step + i) * jcp.ch_block * sizeof(float)); |
1103 | Vmm vmm_acc = get_acc_reg(i * jcp.nb_ch_blocking + ch); |
1104 | load_xmm( |
1105 | vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load); |
1106 | } |
1107 | } |
1108 | } |
1109 | |
1110 | template <cpu_isa_t isa> |
1111 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() { |
1112 | for (int ch = 0; ch < jcp.nb_ch_blocking; ++ch) { |
1113 | for (int r = 0; r < reg_repeats_; ++r) { |
1114 | Vmm vmm_bias = get_bias_reg(r * jcp.nb_ch_blocking + ch); |
1115 | uni_vpxor(vmm_bias, vmm_bias, vmm_bias); |
1116 | } |
1117 | } |
1118 | } |
1119 | |
1120 | template <> |
1121 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::load_bias( |
1122 | int nb_ch_blocking, bool is_last_ch) { |
1123 | for (int r = 0; r < reg_repeats_; ++r) { |
1124 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1125 | bool masked_load = tail_in_first_simd && is_last_ch; |
1126 | size_t half_ch_block_offset |
1127 | = static_cast<size_t>(r * simd_w_ * sizeof(float)); |
1128 | Vmm vmm_bias = get_bias_reg(r); |
1129 | load_xmm(vmm_bias, vmmword[reg_bias_baddr + half_ch_block_offset], |
1130 | masked_load); |
        if (masked_load) break; // tail ends in this half; skip remaining repeats
1132 | } |
1133 | } |
1134 | |
1135 | template <cpu_isa_t isa> |
1136 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias( |
1137 | int nb_ch_blocking, bool is_last_ch) { |
1138 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1139 | bool masked_load = is_last_ch && (ch == nb_ch_blocking - 1); |
1140 | size_t bias_offset |
1141 | = static_cast<size_t>(ch * jcp.ch_block * sizeof(float)); |
1142 | Vmm vmm_bias = get_bias_reg(ch); |
1143 | load_xmm(vmm_bias, vmmword[reg_bias_baddr + bias_offset], masked_load); |
1144 | } |
1145 | } |
1146 | |
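// Unrolled accumulation over one ow block (nxc fast path): load the initial
// window of input registers once, then for each output pixel reload
// ("cascade") only the newly uncovered columns and issue kw FMAs against the
// output.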
1147 | template <cpu_isa_t isa> |
1148 | inline void |
1149 | jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_unroll_ow_step_nxc( |
1150 | int unroll_w, int l_pad, int pad_offset, int ow_block, |
1151 | int nb_ch_blocking, bool is_last_ch) { |
1152 | |
1153 | assert(one_of(isa, avx2, avx512_core)); |
1154 | |
1155 | const size_t ch_step = jcp.ngroups; |
1156 | const int iw_block = ow_block * jcp.stride_w; |
1157 | const int right_border = jcp.iw - iw_block; |
1158 | const int r_pad = jcp.r_pad; |
1159 | const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); |
1160 | |
    /* preamble count for the number of cascaded LOAD + FMA operations */
1162 | const int input_overlap = nstl::max(jcp.kw - l_pad, 0); |
1163 | const bool is_last_block = (unroll_w + ow_block == jcp.ow); |
1164 | |
    /* LOAD initial input registers, then cascade LOADs and FMAs */
1166 | for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { |
1167 | int output_sp_offset = i_ur * ch_step; |
1168 | if (i_ur == 0) { |
1169 | for (int c = 0; c < input_overlap; ++c) { |
1170 | int input_sp = c - pad_offset; |
1171 | int input_sp_offset = input_sp * ch_step; |
1172 | if (input_sp_offset < 0 && unroll_w == jcp.ow) continue; |
1173 | |
                const bool over_steps_bdry = is_last_block
                        && (c - pad_offset + r_pad > right_border);
1176 | if (over_steps_bdry) continue; |
1177 | |
1178 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1179 | bool masked_load = is_last_ch && ch == nb_ch_blocking - 1; |
1180 | size_t input_offset = static_cast<size_t>( |
1181 | (input_sp_offset + ch * simd_w_) * sizeof(float)); |
1182 | Vmm vmm_input = get_input_reg( |
1183 | (c % jcp.kw) * jcp.nb_ch_blocking + ch); |
1184 | load_xmm(vmm_input, ptr[reg_tmp_input + input_offset], |
1185 | masked_load); |
1186 | } |
1187 | } |
1188 | } else { |
1189 | for (int c = 0; c < cascade_input; ++c) { |
1190 | int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; |
1191 | int input_sp = overlap + c - pad_offset; |
1192 | int input_sp_offset = input_sp * ch_step; |
1193 | if (input_sp_offset < 0 || overlap + c + l_pad > right_border) |
1194 | continue; |
1195 | |
                const bool over_steps_bdry = is_last_block
                        && (overlap + c - pad_offset + r_pad > right_border);
1198 | if (over_steps_bdry) continue; |
1199 | |
1200 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1201 | bool masked_load = is_last_ch && ch == nb_ch_blocking - 1; |
1202 | size_t input_offset = static_cast<size_t>( |
1203 | (input_sp_offset + ch * simd_w_) * sizeof(float)); |
1204 | Vmm vmm_input = get_input_reg( |
1205 | ((overlap + c) % jcp.kw) * jcp.nb_ch_blocking + ch); |
1206 | load_xmm(vmm_input, ptr[reg_tmp_input + input_offset], |
1207 | masked_load); |
1208 | } |
1209 | } |
1210 | } |
1211 | for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { |
1212 | int io_overlap = i_kw + (i_ur * jcp.stride_w); |
1213 | |
1214 | /* Don't apply FMAs that fall into the padded region */ |
1215 | if (io_overlap - l_pad < 0 |
1216 | || io_overlap - jcp.l_pad >= right_border) |
1217 | continue; |
1218 | |
1219 | const bool over_steps_bdry = is_last_block |
1220 | && (io_overlap - jcp.l_pad + jcp.r_pad > right_border); |
1221 | if (over_steps_bdry) continue; |
1222 | |
1223 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1224 | bool masked_load = is_last_ch && ch == nb_ch_blocking - 1; |
1225 | size_t output_offset = static_cast<size_t>( |
1226 | (output_sp_offset + ch * simd_w_) * sizeof(float)); |
1227 | |
1228 | Vmm vmm_input = get_input_reg( |
1229 | ((io_overlap - l_pad) % jcp.kw) * jcp.nb_ch_blocking |
1230 | + ch); |
1231 | Vmm vmm_acc = get_acc_reg(i_kw * jcp.nb_ch_blocking + ch); |
1232 | if (masked_load) { |
1233 | Vmm vmm_output = get_output_reg(0); |
1234 | load_xmm(vmm_output, ptr[reg_tmp_output + output_offset], |
1235 | true); |
1236 | uni_vfmadd231ps(vmm_acc, vmm_input, vmm_output); |
1237 | } else { |
1238 | uni_vfmadd231ps(vmm_acc, vmm_input, |
1239 | ptr[reg_tmp_output + output_offset]); |
1240 | } |
1241 | } |
1242 | } |
1243 | } |
1244 | } |
1245 | |
1246 | template <cpu_isa_t isa> |
1247 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_unroll_ow_step( |
1248 | int unroll_w, int l_pad, int pad_offset, int ow_block, |
1249 | bool is_last_ch) { |
1250 | |
1251 | const size_t ch_step = is_layout_nxc() ? jcp.ngroups : simd_w_; |
1252 | const int iw_block = ow_block * jcp.stride_w; |
1253 | const int right_border = jcp.iw - iw_block; |
1254 | const int r_pad = jcp.r_pad; |
1255 | const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); |
1256 | |
    /* preamble count for the number of cascaded LOAD + FMA operations */
1258 | const int input_overlap = nstl::max(jcp.kw - l_pad, 0); |
1259 | const bool is_last_block = (unroll_w + ow_block == jcp.ow); |
1260 | const bool nxc_sse41_offset = is_layout_nxc() && isa == sse41; |
1261 | |
    /* LOAD initial input registers, then cascade LOADs and FMAs */
1263 | for (int r = 0; r < reg_repeats_; ++r) { |
1264 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1265 | bool masked_load |
1266 | = IMPLICATION(isa == sse41, tail_in_first_simd) && is_last_ch; |
1267 | for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { |
1268 | int output_sp_offset = nxc_sse41_offset |
1269 | ? i_ur * ch_step + r * simd_w_ |
1270 | : (i_ur * reg_repeats_ + r) * ch_step; |
1271 | size_t output_offset |
1272 | = static_cast<size_t>(output_sp_offset * sizeof(float)); |
1273 | Vmm vmm_output = get_output_reg(r); |
1274 | load_xmm(vmm_output, ptr[reg_tmp_output + output_offset], |
1275 | masked_load); |
1276 | if (i_ur == 0) { |
1277 | for (int c = 0; c < input_overlap; ++c) { |
1278 | int input_sp = c - pad_offset; |
1279 | int input_sp_offset = nxc_sse41_offset |
1280 | ? input_sp * ch_step + r * simd_w_ |
1281 | : (input_sp * reg_repeats_ + r) * ch_step; |
1282 | if (input_sp_offset < 0 && unroll_w == jcp.ow) continue; |
1283 | |
                    const bool over_steps_bdry = is_last_block
                            && (c - pad_offset + r_pad > right_border);
1286 | if (over_steps_bdry) continue; |
1287 | |
1288 | size_t input_offset = static_cast<size_t>( |
1289 | input_sp_offset * sizeof(float)); |
1290 | Vmm vmm_input |
1291 | = get_input_reg((c % jcp.kw) * reg_repeats_ + r); |
1292 | load_xmm(vmm_input, ptr[reg_tmp_input + input_offset], |
1293 | masked_load); |
1294 | } |
1295 | } else { |
1296 | for (int c = 0; c < cascade_input; ++c) { |
1297 | int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; |
1298 | int input_sp = overlap + c - pad_offset; |
1299 | int input_sp_offset = nxc_sse41_offset |
1300 | ? input_sp * ch_step + r * simd_w_ |
1301 | : (input_sp * reg_repeats_ + r) * ch_step; |
1302 | if (input_sp_offset < 0 |
1303 | || overlap + c + l_pad > right_border) |
1304 | continue; |
1305 | |
                    const bool over_steps_bdry = is_last_block
                            && (overlap + c - pad_offset + r_pad
                                    > right_border);
1309 | if (over_steps_bdry) continue; |
1310 | |
1311 | size_t input_offset = static_cast<size_t>( |
1312 | input_sp_offset * sizeof(float)); |
1313 | Vmm vmm_input = get_input_reg( |
1314 | ((overlap + c) % jcp.kw) * reg_repeats_ + r); |
1315 | load_xmm(vmm_input, ptr[reg_tmp_input + input_offset], |
1316 | masked_load); |
1317 | } |
1318 | } |
1319 | for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { |
1320 | int io_overlap = i_kw + (i_ur * jcp.stride_w); |
1321 | |
1322 | /* Don't apply FMAs that fall into the padded region */ |
1323 | if (io_overlap - l_pad < 0 |
1324 | || io_overlap - jcp.l_pad >= right_border) |
1325 | continue; |
1326 | |
1327 | const bool over_steps_bdry = is_last_block |
1328 | && (io_overlap - jcp.l_pad + jcp.r_pad > right_border); |
1329 | if (over_steps_bdry) continue; |
1330 | |
1331 | Vmm vmm_input = get_input_reg( |
1332 | ((io_overlap - l_pad) % jcp.kw) * reg_repeats_ + r); |
1333 | Vmm vmm_acc = get_acc_reg(r * jcp.kw + i_kw); |
1334 | Vmm vmm_aux = isa == sse41 ? get_aux_reg() : vmm_input; |
1335 | if (isa == sse41) uni_vmovups(vmm_aux, vmm_input); |
1336 | uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); |
1337 | } |
1338 | } |
        if (isa == sse41 && masked_load)
            break; // tail ends in this half; skip remaining repeats
1341 | } |
1342 | } |
1343 | |
1344 | template <cpu_isa_t isa> |
1345 | inline void |
1346 | jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::dispatch_ow_step_unroll( |
1347 | int unroll_w, int l_pad, int pad_offset, int ow_block, |
1348 | int nb_ch_blocking, bool is_last_ch) { |
1349 | if (jcp.is_fast_depthwise) { |
1350 | compute_unroll_ow_step_nxc(unroll_w, l_pad, pad_offset, ow_block, |
1351 | nb_ch_blocking, is_last_ch); |
1352 | } else { |
1353 | assert(nb_ch_blocking == 1); |
1354 | compute_unroll_ow_step( |
1355 | unroll_w, l_pad, pad_offset, ow_block, is_last_ch); |
1356 | } |
1357 | } |
1358 | |
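// Bias reduction: accumulate diff_dst over the unrolled output pixels into
// the bias registers.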
1359 | template <> |
1360 | inline void |
1361 | jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::compute_bias_step_unroll( |
1362 | const int unroll_w, int nb_ch_blocking, bool is_last_ch) { |
1363 | const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : simd_w_; |
1364 | for (int r = 0; r < reg_repeats_; ++r) { |
1365 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1366 | bool masked_load = tail_in_first_simd && is_last_ch; |
1367 | for (int i = 0; i < unroll_w; ++i) { |
1368 | int off_output = is_ddst_layout_nxc() |
1369 | ? i * ch_step + r * simd_w_ |
1370 | : (i * reg_repeats_ + r) * ch_step; |
1371 | Vmm vmm_bias = get_bias_reg(r); |
1372 | Vmm vmm_out = get_output_reg(1 + r); |
1373 | addps_xmm(vmm_bias, vmm_out, |
1374 | vmmword[reg_tmp_output + off_output * sizeof(float)], |
1375 | masked_load); |
1376 | } |
        if (masked_load) break; // tail ends in this half; skip remaining repeats
1378 | } |
1379 | } |
1380 | |
1381 | template <cpu_isa_t isa> |
1382 | inline void |
1383 | jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll( |
1384 | const int unroll_w, int nb_ch_blocking, bool is_last_ch) { |
1385 | const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : simd_w_; |
1386 | for (int i = 0; i < unroll_w; ++i) { |
1387 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1388 | Vmm vmm_bias = get_bias_reg(ch); |
1389 | size_t off_output = static_cast<size_t>( |
1390 | (i * ch_step + ch * simd_w_) * sizeof(float)); |
1391 | bool masked_store = is_last_ch && (ch == nb_ch_blocking - 1); |
1392 | bool use_extra_vmm = isa == avx2 && masked_store; |
1393 | Vmm vmm_out = use_extra_vmm ? get_output_reg(1) : vmm_bias; |
1394 | addps_xmm(vmm_bias, vmm_out, vmmword[reg_tmp_output + off_output], |
1395 | masked_store); |
1396 | } |
1397 | } |
1398 | } |
1399 | |
1400 | template <> |
1401 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::store_filter( |
1402 | int nb_ch_blocking, bool is_last_ch) { |
1403 | assert(nb_ch_blocking == 1); |
1404 | for (int r = 0; r < reg_repeats_; ++r) { |
1405 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1406 | bool masked_load = tail_in_first_simd && is_last_ch; |
1407 | const int reg_set = r * jcp.kw; |
1408 | for (int i = 0; i < jcp.kw; ++i) { |
1409 | size_t off_filter = static_cast<size_t>( |
1410 | (i * jcp.ch_block + r * simd_w_) * sizeof(float)); |
1411 | Vmm vmm_acc = get_acc_reg(i + reg_set); |
1412 | store_xmm( |
1413 | vmm_acc, vmmword[reg_tmp_filter + off_filter], masked_load); |
1414 | } |
        if (masked_load) break; // tail ends in this half; skip remaining repeats
1416 | } |
1417 | } |
1418 | |
1419 | template <cpu_isa_t isa> |
1420 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter( |
1421 | int nb_ch_blocking, bool is_last_ch) { |
1422 | size_t filter_step = jcp.kh * jcp.kw; |
1423 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1424 | bool masked_store = is_last_ch && ch == nb_ch_blocking - 1; |
1425 | for (int i = 0; i < jcp.kw; ++i) { |
1426 | size_t off_filter = static_cast<size_t>( |
1427 | (ch * filter_step + i) * jcp.ch_block * sizeof(float)); |
1428 | Vmm vmm_acc = get_acc_reg(i * jcp.nb_ch_blocking + ch); |
1429 | store_xmm(vmm_acc, vmmword[reg_tmp_filter + off_filter], |
1430 | masked_store); |
1431 | } |
1432 | } |
1433 | } |
1434 | |
1435 | template <> |
1436 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>::store_bias( |
1437 | int nb_ch_blocking, bool is_last_ch) { |
1438 | for (int r = 0; r < reg_repeats_; ++r) { |
1439 | bool tail_in_first_simd = (r + 1) * simd_w_ >= jcp.ch_tail; |
1440 | bool masked_load = tail_in_first_simd && is_last_ch; |
1441 | size_t half_ch_block_offset |
1442 | = static_cast<size_t>(r * simd_w_ * sizeof(float)); |
1443 | Vmm vmm_bias = get_bias_reg(r); |
1444 | store_xmm(vmm_bias, vmmword[reg_bias_baddr + half_ch_block_offset], |
1445 | masked_load); |
        if (masked_load) break; // tail ends in this half; skip remaining repeats
1447 | } |
1448 | } |
1449 | |
1450 | template <cpu_isa_t isa> |
1451 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias( |
1452 | int nb_ch_blocking, bool is_last_ch) { |
1453 | for (int ch = 0; ch < nb_ch_blocking; ++ch) { |
1454 | bool masked_store = is_last_ch && ch == nb_ch_blocking - 1; |
1455 | size_t bias_offset = static_cast<size_t>(ch * simd_w_ * sizeof(float)); |
1456 | Vmm vmm_bias = get_bias_reg(ch); |
1457 | store_xmm( |
1458 | vmm_bias, vmmword[reg_bias_baddr + bias_offset], masked_store); |
1459 | } |
1460 | } |
1461 | |
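// Loop over the assigned output rows, reducing diff_dst into the bias
// accumulators in unroll_w-wide chunks plus a tail.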
1462 | template <cpu_isa_t isa> |
1463 | inline void |
1464 | jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_spatial_loop_bias( |
1465 | int nb_ch_blocking, bool is_last_ch) { |
1466 | Label oh_label; |
1467 | Label ow_blk_label; |
1468 | |
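    // jcp.ow is covered by 'unroll_w_trips' blocks of 'unroll_w' columns,
    // plus a 'tail_w' remainder (non-zero only if ow exceeds max_unroll_w_).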
1469 | const int unroll_w = nstl::min(max_unroll_w_, jcp.ow); |
1470 | const int unroll_w_trips = jcp.ow / unroll_w; |
1471 | const int tail_w = jcp.ow > max_unroll_w_ ? jcp.ow % max_unroll_w_ : 0; |
1472 | |
1473 | const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block; |
1474 | const size_t ch_offset = ch_step * sizeof(float); |
1475 | |
1476 | mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]); |
1477 | mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]); |
1478 | |
1479 | mov(reg_tmp_output, reg_output_baddr); |
1480 | L(oh_label); |
1481 | { |
1482 | |
1483 | mov(reg_iter_ow_blk, unroll_w_trips); |
1484 | L(ow_blk_label); |
1485 | { |
1486 | compute_bias_step_unroll(unroll_w, nb_ch_blocking, is_last_ch); |
1487 | add(reg_tmp_output, unroll_w * ch_offset); |
1488 | |
1489 | dec(reg_iter_ow_blk); |
1490 | cmp(reg_iter_ow_blk, 0); |
1491 | jg(ow_blk_label, T_NEAR); |
1492 | } |
1493 | |
1494 | if (tail_w > 0) { |
1495 | compute_bias_step_unroll(tail_w, nb_ch_blocking, is_last_ch); |
1496 | add(reg_tmp_output, tail_w * ch_offset); |
1497 | } |
1498 | |
1499 | inc(reg_oh); |
1500 | cmp(reg_oh, reg_oh_worksize); |
1501 | jl(oh_label, T_NEAR); |
1502 | } |
1503 | } |
1504 | |
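/* Blocked-layout bias path: a single channel block per kernel call. When a
 * channel tail exists, FLAG_OC_LAST selects the masked variant at run time. */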
1505 | template <cpu_isa_t isa> |
1506 | void jit_uni_dw_conv_bwd_weights_kernel_f32< |
1507 | isa>::compute_single_ch_block_bias() { |
1508 | |
1509 | auto write_compute_bias = [&](bool is_last_ch) { |
1510 | Label skip_load_bias; |
1511 | |
1512 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1513 | and_(reg_exec_flags, FLAG_ZERO_BIAS); |
1514 | test(reg_exec_flags, reg_exec_flags); |
1515 | jne(skip_load_bias); |
1516 | |
1517 | assert(jcp.nb_ch_blocking == 1); |
1518 | load_bias(jcp.nb_ch_blocking, is_last_ch); |
1519 | |
1520 | L(skip_load_bias); |
1521 | compute_spatial_loop_bias(jcp.nb_ch_blocking, is_last_ch); |
1522 | |
1523 | store_bias(jcp.nb_ch_blocking, is_last_ch); |
1524 | }; |
1525 | |
1526 | Label skip_masked_bias_label, done_bias_label; |
1527 | |
1528 | zero_bias(); |
1529 | |
1530 | bool do_bias_ch_tail = jcp.ch_tail > 0; |
1531 | if (do_bias_ch_tail) { |
1532 | // test last channel |
1533 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1534 | and_(reg_exec_flags, FLAG_OC_LAST); |
1535 | test(reg_exec_flags, reg_exec_flags); |
1536 | jz(skip_masked_bias_label, T_NEAR); |
1537 | |
1538 | write_compute_bias(true); |
1539 | |
1540 | jmp(done_bias_label, T_NEAR); |
1541 | L(skip_masked_bias_label); |
1542 | } |
1543 | |
1544 | write_compute_bias(false); |
1545 | |
1546 | L(done_bias_label); |
1547 | } |
1548 | |
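/* nxc bias path: emit the computation for a full group of channel blocks
 * and, when a remainder or channel tail exists, a second (possibly masked)
 * variant for the last group, selected at run time via FLAG_OC_LAST.
 * 'do_load_bias' chooses between accumulating on top of previously stored
 * bias values and starting from zero. */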
1549 | template <cpu_isa_t isa> |
1550 | void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ch_loop_bias( |
1551 | bool do_load_bias) { |
1552 | |
1553 | assert(is_ddst_layout_nxc()); |
1554 | |
1555 | auto write_compute_bias = [&](int nb_ch_blocking, bool is_last_ch) { |
1556 | if (do_load_bias) |
1557 | load_bias(nb_ch_blocking, is_last_ch); |
1558 | else |
1559 | zero_bias(); |
1560 | compute_spatial_loop_bias(nb_ch_blocking, is_last_ch); |
1561 | store_bias(nb_ch_blocking, is_last_ch); |
1562 | }; |
1563 | |
1564 | if (jcp.nb_ch > jcp.nb_ch_blocking) { |
1565 | |
1567 | const bool masked_ch_tail = jcp.ch_tail > 0; |
1568 | const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking; |
1569 | const bool unroll_last_ch_block |
1570 | = nb_ch_blocking_tail > 0 || masked_ch_tail; |
1571 | const int last_ch_block = nb_ch_blocking_tail > 0 ? nb_ch_blocking_tail |
1572 | : jcp.nb_ch_blocking; |
1573 | |
1574 | push(reg_output_baddr); |
1575 | |
1576 | Label last_ch_block_label, ch_block_done_label; |
1577 | if (unroll_last_ch_block) { |
1578 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1579 | and_(reg_exec_flags, FLAG_OC_LAST); |
1580 | test(reg_exec_flags, reg_exec_flags); |
1581 | jnz(last_ch_block_label, T_NEAR); |
1582 | } |
1583 | |
1584 | write_compute_bias(jcp.nb_ch_blocking, false); |
1585 | |
1586 | if (unroll_last_ch_block) { |
1587 | jmp(ch_block_done_label, T_NEAR); |
1588 | |
1589 | L(last_ch_block_label); |
1590 | write_compute_bias(last_ch_block, masked_ch_tail); |
1591 | L(ch_block_done_label); |
1592 | } |
1593 | |
1594 | pop(reg_output_baddr); |
1595 | |
1596 | } else { |
1597 | bool masked_ch_tail = jcp.ch_tail > 0; |
1598 | write_compute_bias(jcp.nb_ch_blocking, masked_ch_tail); |
1599 | } |
1600 | } |
1601 | |
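/* Run-time dispatch on FLAG_ZERO_BIAS: either continue accumulating into the
 * bias values already in memory, or restart the accumulation from zero. */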
1602 | template <cpu_isa_t isa> |
1603 | void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::deploy_ch_loop_bias() { |
1604 | |
    Label zero_bias_label, load_bias_done_label;
1606 | |
1607 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1608 | and_(reg_exec_flags, FLAG_ZERO_BIAS); |
1609 | test(reg_exec_flags, reg_exec_flags); |
1610 | jne(zero_bias_label, T_NEAR); |
1611 | |
1612 | compute_ch_loop_bias(true); // load_bias |
1613 | jmp(load_bias_done_label, T_NEAR); |
1614 | |
1615 | L(zero_bias_label); |
1616 | compute_ch_loop_bias(false); // zero_bias |
1617 | |
1618 | L(load_bias_done_label); |
1619 | } |
1620 | |
1621 | template <cpu_isa_t isa> |
1622 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias() { |
1623 | |
1624 | mov(reg_bias_baddr, ptr[this->param1 + GET_OFF(bias)]); |
1625 | |
1626 | if (is_ddst_layout_nxc()) |
1627 | deploy_ch_loop_bias(); |
1628 | else |
1629 | compute_single_ch_block_bias(); |
1630 | } |
1631 | |
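/* Store the (previously zeroed) accumulator registers to every kh row of the
 * filter block, then rewind the filter pointer to its starting address. */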
1632 | template <cpu_isa_t isa> |
1633 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter_kh_loop( |
1634 | int nb_ch_blocking) { |
1635 | |
1636 | const size_t filter_offset_kw = jcp.kw * jcp.ch_block * sizeof(float); |
1637 | const size_t filter_offset_kh = jcp.kh * filter_offset_kw; |
1638 | |
1639 | Label kh_loop_label; |
1640 | |
1641 | mov(reg_kh_aux, jcp.kh); |
1642 | L(kh_loop_label); |
1643 | { |
1644 | store_filter(nb_ch_blocking); |
1645 | |
1646 | add(reg_tmp_filter, filter_offset_kw); |
1647 | dec(reg_kh_aux); |
1648 | cmp(reg_kh_aux, 0); |
1649 | jg(kh_loop_label, T_NEAR); |
1650 | } |
1651 | |
    /* Rewind the filter pointer */
1653 | sub(reg_tmp_filter, filter_offset_kh); |
1654 | } |
1655 | |
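/* For nxc layout with more channel blocks than the blocking factor, choose at
 * run time (via FLAG_OC_LAST) between zeroing a full group of channel blocks
 * and zeroing the remainder group. */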
1656 | template <cpu_isa_t isa> |
1657 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter_ch_loop() { |
1658 | |
1659 | bool write_ch_blocking_unroll |
1660 | = is_layout_nxc() && jcp.nb_ch > jcp.nb_ch_blocking; |
1661 | if (write_ch_blocking_unroll) { |
1662 | const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking; |
1663 | |
1664 | Label last_ch_block_label, ch_block_done_label; |
1665 | |
1666 | if (nb_ch_blocking_tail) { |
1667 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1668 | and_(reg_exec_flags, FLAG_OC_LAST); |
1669 | test(reg_exec_flags, reg_exec_flags); |
1670 | jnz(last_ch_block_label, T_NEAR); |
1671 | } |
1672 | |
1673 | zero_filter_kh_loop(jcp.nb_ch_blocking); |
1674 | |
1675 | if (nb_ch_blocking_tail) { |
1676 | jmp(ch_block_done_label, T_NEAR); |
1677 | |
1678 | L(last_ch_block_label); |
1679 | zero_filter_kh_loop(nb_ch_blocking_tail); |
1680 | L(ch_block_done_label); |
1681 | } |
1682 | } else { |
1683 | zero_filter_kh_loop(jcp.nb_ch_blocking); |
1684 | } |
1685 | } |
1686 | |
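/* Zero the filter accumulation buffer when FLAG_ZERO_FILTER is set, so that
 * subsequent accumulation starts from a clean slate. */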
1687 | template <cpu_isa_t isa> |
1688 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::deploy_zero_filter() { |
1689 | |
1690 | Label skip_zeroing_label; |
1691 | |
1692 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1693 | and_(reg_exec_flags, FLAG_ZERO_FILTER); |
1694 | test(reg_exec_flags, reg_exec_flags); |
1695 | je(skip_zeroing_label, T_NEAR); |
1696 | |
1697 | zero_filter(); |
1698 | |
1699 | mov(reg_tmp_filter, reg_filter_baddr); |
1700 | zero_filter_ch_loop(); |
1701 | |
1702 | L(skip_zeroing_label); |
1703 | } |
1704 | |
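/* Innermost accumulation over kh: for each filter row, load the current
 * filter values, accumulate the unrolled ow contributions, store the row
 * back, and finally rewind the input and filter pointers. */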
1705 | template <cpu_isa_t isa> |
1706 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_kh_step( |
1707 | int unroll_w, int l_pad, int pad_offset, int ow_block, |
1708 | int nb_ch_blocking, bool is_last_ch) { |
1709 | |
1710 | const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block; |
1711 | const size_t input_offset = jcp.iw * ch_step * sizeof(float); |
1712 | const size_t filter_offset = jcp.kw * jcp.ch_block * sizeof(float); |
1713 | |
1714 | Label kh_loop_label, skip_loop_label; |
1715 | |
1716 | cmp(reg_kh, 0); |
1717 | je(skip_loop_label, T_NEAR); |
1718 | |
1719 | mov(reg_kh_aux, reg_kh); |
1720 | L(kh_loop_label); |
1721 | { |
1722 | load_filter(nb_ch_blocking, is_last_ch); |
1723 | dispatch_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block, |
1724 | nb_ch_blocking, is_last_ch); |
1725 | store_filter(nb_ch_blocking, is_last_ch); |
1726 | |
1727 | add(reg_tmp_filter, filter_offset); |
1728 | add(reg_tmp_input, input_offset); |
1729 | dec(reg_kh_aux); |
1730 | cmp(reg_kh_aux, 0); |
1731 | jg(kh_loop_label, T_NEAR); |
1732 | } |
1733 | |
    /* Rewind the input and filter pointers */
1735 | Label kh_comeback_label; |
1736 | mov(reg_kh_aux, reg_kh); |
1737 | L(kh_comeback_label); |
1738 | { |
1739 | sub(reg_tmp_input, input_offset); |
1740 | sub(reg_tmp_filter, filter_offset); |
1741 | dec(reg_kh_aux); |
1742 | cmp(reg_kh_aux, 0); |
1743 | jg(kh_comeback_label, T_NEAR); |
1744 | } |
1745 | |
1746 | L(skip_loop_label); |
1747 | } |
1748 | |
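/* Same run-time channel dispatch as the bias path: a full nb_ch_blocking
 * group, or a (possibly masked) last group selected via FLAG_OC_LAST. */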
1749 | template <cpu_isa_t isa> |
1750 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ch_loop( |
1751 | int unroll_w, int l_pad, int pad_offset, int ow_block) { |
1752 | |
1753 | bool write_ch_blocking_unroll |
1754 | = is_layout_nxc() && jcp.nb_ch > jcp.nb_ch_blocking; |
1755 | if (write_ch_blocking_unroll) { |
1756 | |
1757 | const bool masked_ch_tail = jcp.ch_tail > 0; |
1758 | const int nb_ch_blocking_tail = jcp.nb_ch % jcp.nb_ch_blocking; |
1759 | const int last_ch_block = nb_ch_blocking_tail > 0 ? nb_ch_blocking_tail |
1760 | : jcp.nb_ch_blocking; |
1761 | const bool unroll_last_ch_block |
1762 | = nb_ch_blocking_tail > 0 || masked_ch_tail; |
1763 | |
1764 | Label last_ch_block_label, ch_block_done_label; |
1765 | if (unroll_last_ch_block) { |
1766 | mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]); |
1767 | and_(reg_exec_flags, FLAG_OC_LAST); |
1768 | test(reg_exec_flags, reg_exec_flags); |
1769 | jnz(last_ch_block_label, T_NEAR); |
1770 | } |
1771 | |
1772 | compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, |
1773 | jcp.nb_ch_blocking, false); |
1774 | |
1775 | if (unroll_last_ch_block) { |
1776 | jmp(ch_block_done_label, T_NEAR); |
1777 | |
1778 | L(last_ch_block_label); |
1779 | compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, |
1780 | last_ch_block, masked_ch_tail); |
1781 | L(ch_block_done_label); |
1782 | } |
1783 | } else { |
1784 | bool masked_ch_tail = jcp.ch_tail > 0 && is_layout_nxc(); |
1785 | compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, |
1786 | jcp.nb_ch_blocking, masked_ch_tail); |
1787 | } |
1788 | } |
1789 | |
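/* Walk the output rows assigned to this call, adjusting the kernel overlap
 * (reg_kh) and the input/filter pointers while the filter window slides
 * through the top and bottom padding regions. */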
1790 | template <cpu_isa_t isa> |
1791 | inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop( |
1792 | int unroll_w, int l_pad, int pad_offset, int ow_block) { |
1793 | |
1794 | mov(reg_tmp_output, reg_output_baddr); |
1795 | mov(reg_tmp_input, reg_input_baddr); |
1796 | mov(reg_tmp_filter, reg_filter_baddr); |
1797 | |
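    // Index of the first output row whose filter window extends past the
    // bottom of the input, i.e. starts overlapping the bottom padding.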
1798 | const int input_bottom_padding_overlap |
1799 | = div_up(jcp.ih + jcp.t_pad - (jcp.kh - 1), jcp.stride_h); |
1800 | |
1801 | const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block; |
1802 | const size_t typesize = sizeof(float); |
1803 | const size_t input_shift = typesize * jcp.iw * ch_step; |
1804 | const size_t output_shift = typesize * jcp.ow * ch_step; |
1805 | const size_t filter_shift = typesize * jcp.kw * jcp.ch_block; |
1806 | |
1807 | Label loop_begin_label, loop_end_label, common_block_label, |
1808 | top_padding_end_label, bottom_padding_end_label, |
1809 | bottom_padding_label; |
1810 | |
1811 | mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]); |
1812 | mov(reg_kh, ptr[this->param1 + GET_OFF(kh_count)]); |
1813 | |
1814 | // replacement for 'os_index_end' |
1815 | mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]); |
1816 | |
1817 | cmp(reg_kh, 0); |
1818 | jle(loop_end_label, T_NEAR); // no iterations along kh |
1819 | cmp(reg_oh, reg_oh_worksize); |
1820 | jge(loop_end_label, T_NEAR); // no iterations along height dimension |
1821 | |
1822 | L(loop_begin_label); |
1823 | |
1824 | compute_ch_loop(unroll_w, l_pad, pad_offset, ow_block); |
1825 | |
1826 | /* Compute 'top' edge */ |
1827 | if (jcp.t_pad > 0) { |
1828 | |
1829 | /* Check if within top padding region */ |
1830 | cmp(reg_oh, div_up(jcp.t_pad, jcp.stride_h)); |
1831 | jge(top_padding_end_label, T_NEAR); |
1832 | |
        /* Still in the top padding region: each output row increases the
         * kernel/input overlap by stride_h rows, so grow reg_kh and move the
         * filter pointer back accordingly. */
1834 | sub(reg_tmp_filter, filter_shift * jcp.stride_h); |
1835 | add(reg_kh, jcp.stride_h); |
1836 | |
1837 | /* Final number of kernel elements that overlap with input */ |
1838 | const int inp_ker_overlap = nstl::min(jcp.kh, jcp.ih); |
1839 | cmp(reg_kh, inp_ker_overlap); |
1840 | jle(common_block_label, T_NEAR); |
1841 | |
1842 | /* Correct any excess shifts to kernel and input */ |
1843 | if (jcp.t_pad <= jcp.oh * jcp.stride_h) { |
1844 | /* Filter has moved beyond padding (adjust for stride effects) */ |
1845 | if (jcp.t_pad % jcp.stride_h != 0) { |
1846 | int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; |
1847 | add(reg_tmp_filter, filter_shift * inp_corr); |
1848 | add(reg_tmp_input, input_shift * inp_corr); |
1849 | } |
1850 | } else { |
1851 | /* Filter still overlaps padding (complete reset) */ |
1852 | sub(reg_tmp_filter, |
1853 | (jcp.t_pad - jcp.oh * jcp.stride_h) * filter_shift); |
1854 | } |
1855 | |
1856 | /* Apply correction */ |
1857 | mov(reg_kh, inp_ker_overlap); |
1858 | jmp(common_block_label); |
1859 | |
1860 | L(top_padding_end_label); |
1861 | } |
1862 | |
1863 | /* Compute 'bottom' edge */ |
1864 | if (jcp.b_pad > 0) { |
1865 | |
1866 | /* Check if within bottom padding region */ |
1867 | cmp(reg_oh, input_bottom_padding_overlap - 1); |
1868 | jl(bottom_padding_end_label, T_NEAR); |
1869 | jg(bottom_padding_label, T_NEAR); |
1870 | |
1871 | /* Execute overlap correction between the filter and the initial |
1872 | * bottom padding region. */ |
1873 | mov(reg_kh, |
1874 | jcp.ih + jcp.t_pad |
1875 | - input_bottom_padding_overlap * jcp.stride_h); |
1876 | jmp(bottom_padding_end_label, T_NEAR); |
1877 | |
1878 | L(bottom_padding_label); |
1879 | sub(reg_kh, jcp.stride_h); |
1880 | cmp(reg_kh, 0); |
1881 | jle(loop_end_label, T_NEAR); |
1882 | |
1883 | L(bottom_padding_end_label); |
1884 | } |
1885 | |
1886 | /* Compute middle block */ |
1887 | add(reg_tmp_input, input_shift * jcp.stride_h); |
1888 | |
1889 | /* Execute common block and loop */ |
1890 | L(common_block_label); |
1891 | add(reg_tmp_output, output_shift); |
1892 | inc(reg_oh); |
1893 | cmp(reg_oh, reg_oh_worksize); |
1894 | jl(loop_begin_label, T_NEAR); |
1895 | |
1896 | L(loop_end_label); |
1897 | } |
1898 | |
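/* Split ow into unrolled blocks plus a tail, rebalancing when necessary so
 * that the columns affected by right padding always land in the tail. */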
1899 | template <cpu_isa_t isa> |
1900 | void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::calculate_w_unrolling( |
1901 | int &unroll_trips, int &unroll_w, int &unroll_w_tail) { |
1902 | |
1903 | const bool do_unroll_w = jcp.ow > max_unroll_w_; |
1904 | if (do_unroll_w) { |
1905 | unroll_w = nstl::min(block_size_, jcp.ow); |
1906 | unroll_trips = jcp.ow / unroll_w; |
1907 | /* calculate tail */ |
1908 | unroll_w_tail = jcp.ow % unroll_w; |
        /* Perform some rebalancing if the tail is too small */
1910 | if ((unroll_w_tail == 0 && jcp.r_pad != 0) |
1911 | || (jcp.r_pad > 0 && jcp.r_pad >= unroll_w_tail)) { |
1912 | if (unroll_trips > 1) { |
1913 | unroll_w_tail += unroll_w; |
1914 | unroll_trips--; |
1915 | } else { |
                /* Ideally, this case shouldn't happen */
1917 | unroll_w_tail += (unroll_w - unroll_w / 2); |
1918 | unroll_w = unroll_w / 2; |
1919 | } |
1920 | } |
1921 | } else { |
1922 | unroll_w_tail = jcp.ow; |
1923 | } |
1924 | } |
1925 | |
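/* Top-level ow driver: compute the left-padded block separately, loop over
 * the middle blocks, and finish with the (right-padded) tail block. */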
1926 | template <cpu_isa_t isa> |
1927 | inline void |
1928 | jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() { |
1929 | |
1930 | Label ow_blk_label; // for computing 'ow middle' block |
1931 | int pad_offset = 0; |
1932 | int l_pad = jcp.l_pad; |
1933 | |
1934 | int unroll_w_tail = 0; |
1935 | int unroll_w = 0; |
1936 | int unroll_trips = 0; |
1937 | calculate_w_unrolling(unroll_trips, unroll_w, unroll_w_tail); |
1938 | |
1939 | const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block; |
1940 | const size_t data_offset = unroll_w * ch_step * sizeof(float); |
1941 | |
1942 | if (jcp.with_bias) compute_bias(); |
1943 | |
    deploy_zero_filter();

    /* Offset the filter base address to account for h_padding. */
    mov(reg_kh_offset, ptr[this->param1 + GET_OFF(filter_pad_off)]);
    add(reg_filter_baddr, reg_kh_offset);
1948 | |
1949 | /* compute left padded block */ |
1950 | const bool do_unroll_w = jcp.ow > max_unroll_w_; |
1951 | if (l_pad && do_unroll_w) { |
1952 | compute_h_loop(unroll_w, l_pad, 0, 0); |
1953 | add(reg_output_baddr, data_offset); |
1954 | add(reg_input_baddr, data_offset * jcp.stride_w); |
1955 | unroll_trips--; |
1956 | pad_offset = l_pad; |
1957 | l_pad = 0; |
1958 | } |
1959 | |
1960 | /* Insert loop for 'ow' block when middle block needs to execute more |
1961 | * than once */ |
1962 | bool do_ow_blk_loop = unroll_trips > 1; |
1963 | if (do_ow_blk_loop) { |
1964 | mov(reg_iter_ow_blk, unroll_trips); |
1965 | L(ow_blk_label); |
1966 | } |
1967 | if (unroll_trips > 0) { |
1968 | compute_h_loop(unroll_w, l_pad, pad_offset, 0); |
1969 | add(reg_output_baddr, data_offset); |
1970 | add(reg_input_baddr, data_offset * jcp.stride_w); |
1971 | } |
1972 | if (do_ow_blk_loop) { |
1973 | dec(reg_iter_ow_blk); |
1974 | cmp(reg_iter_ow_blk, 0); |
1975 | jg(ow_blk_label, T_NEAR); |
1976 | } |
1977 | |
1978 | /* compute right padded block */ |
1979 | if (unroll_w_tail) { |
1980 | compute_h_loop( |
1981 | unroll_w_tail, l_pad, pad_offset, jcp.ow - unroll_w_tail); |
1982 | } |
1983 | } |
1984 | |
1985 | template <cpu_isa_t isa> |
1986 | void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() { |
1987 | assert(is_src_layout_nxc() == is_ddst_layout_nxc()); |
1988 | |
1989 | preamble(); |
1990 | |
1991 | mov(reg_input_baddr, ptr[this->param1 + GET_OFF(input)]); |
1992 | mov(reg_output_baddr, ptr[this->param1 + GET_OFF(output)]); |
1993 | mov(reg_filter_baddr, ptr[this->param1 + GET_OFF(filter)]); |
1994 | |
1995 | bool set_kmask = isa > avx2 && jcp.ch_tail > 0 |
1996 | && (jcp.with_bias || is_layout_nxc()); |
1997 | if (set_kmask) { |
1998 | // Prepare masks for tail |
1999 | Reg32 reg_tmp_32 = reg_tmp.cvt32(); |
2000 | mov(reg_tmp_32, (1 << jcp.ch_tail) - 1); |
2001 | kmovw(k_ch_tail_mask, reg_tmp_32); |
2002 | } |
2003 | |
2004 | compute_ow_block_unroll(); |
2005 | |
2006 | this->postamble(); |
2007 | } |
2008 | #undef GET_OFF |
2009 | |
2010 | template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_core>; |
2011 | template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>; |
2012 | template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse41>; |
2013 | |
2014 | } // namespace x64 |
2015 | } // namespace cpu |
2016 | } // namespace impl |
2017 | } // namespace dnnl |
2018 | |