1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/nstl.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
23 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
24 | #include "cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp" |
25 | |
26 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | using namespace Xbyak; |
34 | using namespace dnnl::impl::utils; |
35 | |
jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16(
        const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md)
    : jit_generator(jit_name()), jcp(ajcp) {
    // Set up the post-ops injector when an eltwise or binary post-op is
    // attached to this convolution.
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        // zmm31 is handed to the injector as its scratch vector register.
        static constexpr size_t helper_vmm_idx = 31;
        static constexpr bool use_exact_tail_scalar_bcast = true;
        // Number of valid elements in the last (partial) channel block;
        // zero when channels divide the vector length evenly.
        const size_t tail_size = jcp.oc_without_padding
                % (cpu_isa_traits<avx512_core>::vlen / sizeof(float));

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r12, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
    // Without native bf16 instructions, route bf16 ops through the software
    // emulator, which needs the reserved helper registers below.
    if (!isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
}
65 | |
66 | int jit_avx512_dw_conv_fwd_kernel_bf16::get_acc_reg_idx(int idx) const { |
67 | assert(idx + acc_idx_start <= get_max_regs()); |
68 | return idx + acc_idx_start; |
69 | } |
70 | |
71 | Xbyak::Zmm jit_avx512_dw_conv_fwd_kernel_bf16::get_acc_reg(int idx) { |
72 | return Xbyak::Zmm(get_acc_reg_idx(idx)); |
73 | } |
74 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::load_src(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {
    // Initialize the accumulator tile (ur_ch_blocks x ur_w): start from the
    // bias values (or zero), then for the sum post-op add the previously
    // stored destination values.

    const auto dst_layout_nxc = is_dst_layout_nxc();
    const auto ch_blk = jcp.ch_block;
    // Element strides between channel blocks / adjacent ow points; they
    // differ between the nxc and blocked destination layouts.
    const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
    const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;

    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        // Only the last channel block may be partial, hence masked.
        const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
        for (int ow = 0; ow < ur_w; ow++) {
            Zmm zmm_acc = get_acc_reg(ch * ur_w + ow);
            const Zmm zmm_acc_msk
                    = mask_flag ? zmm_acc | ktail_mask | T_z : zmm_acc;

            if (this->jcp.with_bias) {
                int b_off = ch * ch_blk;
                uni_vmovups(
                        zmm_acc_msk, vmmword[reg_bias + b_off * sizeof(float)]);
            } else {
                uni_vpxor(zmm_acc, zmm_acc, zmm_acc);
            }
            if (this->jcp.with_sum) {
                int o_off = ch * ocb_stride + ow * ow_stride;
                if (jcp.dst_dt == data_type::bf16) {
                    const Zmm zmm_prev_dst_msk = mask_flag
                            ? zmm_prev_dst | ktail_mask | T_z
                            : zmm_prev_dst;
                    // Up-convert bf16 dst to f32: zero-extend each 16-bit
                    // element to 32 bits, then shift left by 16 so the bf16
                    // payload lands in the high half of the f32 encoding.
                    vpmovzxwd(zmm_prev_dst_msk,
                            vmmword[reg_output + o_off * jcp.typesize_out]);
                    vpslld(zmm_prev_dst, zmm_prev_dst, 16);
                    vaddps(zmm_acc, zmm_prev_dst);
                } else {
                    uni_vaddps(zmm_acc_msk, zmm_acc_msk,
                            vmmword[reg_output + o_off * jcp.typesize_out]);
                }
            }
        }
    }
}
115 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::apply_filter_unrolled(int ur_ch_blocks,
        int ur_w, int pad_l, int pad_r, bool last_ch_block_flag) {
    // Accumulate the depthwise convolution for an ur_ch_blocks x ur_w tile.
    // kw and ow are fully unrolled at compile time; only the kernel-height
    // dimension is a runtime loop driven by reg_kh.
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    const auto src_layout_nxc = is_src_layout_nxc();
    // Element strides within the source for nxc vs blocked layouts.
    const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk;
    const auto ih_stride = jcp.iw * iw_stride;
    const auto icb_stride = src_layout_nxc
            ? ch_blk
            : (jcp.is_fused_conv ? 1 : jcp.ih) * jcp.iw * ch_blk;

    Label iter_exit_label;

    // Nothing to accumulate when no kernel rows remain (fully padded call).
    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label);
    {
        if (jcp.is_fused_conv) {
            // Fused conv: the "input" is a buffer of row pointers; fetch the
            // current row and apply the running iw offset.
            mov(aux_reg_input, ptr[aux_reg_input_buffer_ptr]);
            add(aux_reg_input, reg_iw_offset);
        }
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
            for (int kw = 0; kw < jcp.kw; kw++) {
                int ker_off = ch * jcp.kh * jcp.kw * ch_blk + kw * ch_blk;
                const Zmm zmm_ker_reg_msk = mask_flag
                        ? zmm_ker_reg | ktail_mask | T_z
                        : zmm_ker_reg;
                // Load one bf16 weight vector, zero-extended to 32-bit lanes.
                vpmovzxwd(zmm_ker_reg_msk,
                        ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
                // Restrict the ow range so no padded input columns are read.
                int ow_start = get_ow_start(kw, pad_l);
                int ow_end = get_ow_end(ur_w, kw, pad_r);
                for (int ow = ow_start; ow < ow_end; ow++) {
                    const Zmm zmm_src_reg_msk = mask_flag
                            ? zmm_src_reg | ktail_mask | T_z
                            : zmm_src_reg;
                    Zmm zmm_acc = get_acc_reg(ch * ur_w + ow);
                    int inp_off = ch * icb_stride
                            + (ow * stride_w - pad_l) * iw_stride
                            + kw * dilate_w * iw_stride;
                    /* zero-extend bf16 to packed 32-bit int */
                    vpmovzxwd(zmm_src_reg_msk,
                            ptr[aux_reg_input + inp_off * jcp.typesize_in]);
                    // bf16 dot-product accumulate; software-emulated when the
                    // ISA lacks native bf16 support.
                    if (isa_has_bf16(jcp.isa))
                        vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg);
                    else
                        bf16_emu_->vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg);
                }
            }
        }

        // Advance to the next kernel row and its matching input row.
        add(aux_reg_kernel, jcp.kw * ch_blk * jcp.typesize_in);
        if (jcp.is_fused_conv) {
            // Move to next row pointer in the buffer
            add(aux_reg_input_buffer_ptr, sizeof(void *));
        } else {
            add(aux_reg_input, ih_stride * dilate_h * jcp.typesize_in);
        }

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
188 | |
template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w,
        const bool mask_tail, const F &f) {
    // Visit every (channel-block, output-width) pair in row-major order.
    // When mask_tail is set, the final channel block is flagged so the
    // callback can apply tail masking to it.
    for (int cb = 0; cb < ur_ch_blocks; ++cb) {
        const bool tail_block = mask_tail && (cb == ur_ch_blocks - 1);
        for (int w = 0; w < ur_w; ++w)
            f(cb, w, tail_block);
    }
}
template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w, const F &f) {
    // Convenience overload: iterate without any tail masking.
    iterate(ur_ch_blocks, ur_w, false, f);
}
202 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {
    // Apply eltwise/binary post-ops to every live accumulator register of
    // the current tile.
    if (this->jcp.with_eltwise || this->jcp.with_binary) {

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            // Two parameter sets are built: one with tail masking for the
            // last (partial) channel block, one without.
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto mask_tail = jcp.oc_without_padding % jcp.ch_block;
            const auto dst_layout_nxc = is_dst_layout_nxc();
            const auto ch_blk = jcp.ch_block;
            const auto ocb_stride
                    = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
            const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
            // For the blocked layout, whether this call touches the channel
            // tail is only known at run time, so it is dispatched below on
            // the load_work value rather than at compile time.
            const bool mask_tail_blocked_layout
                    = jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc;
            iterate(ur_ch_blocks, ur_w, mask_tail,
                    [&](int ch, int ow, int mask_flag) {
                        // Byte offset of this accumulator's dst element,
                        // needed by the binary injector for rhs addressing.
                        const size_t aux_output_l_off = jcp.typesize_out
                                * (ch * ocb_stride + ow * ow_stride);
                        const auto vmm_idx = get_acc_reg_idx(ch * ur_w + ow);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, reg_output);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            // The non-tail variant shares everything but the tail index set.
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            Label postops_done;
            if (mask_tail_blocked_layout) {
                // Runtime dispatch: use tail parameters only when less than
                // a full channel step of work remains.
                Label postops_no_tail;
                mov(reg_tmp, ptr[param1 + GET_OFF(load_work)]);
                cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block);
                jge(postops_no_tail, T_NEAR);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params);
            } else if (last_ch_block_flag)
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
            else /* if (!last_ch_block_flag) */
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            // Eltwise-only: per-register output descriptors are not needed.
            iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) {
                vmm_idxs.emplace(get_acc_reg_idx(ch * ur_w + ow));
            });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}
264 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::store_dst(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {
    // Write the accumulator tile back to memory, down-converting f32 -> bf16
    // where required. The path taken depends on dst layout, dst data type,
    // and whether the ISA has native bf16 support.

    const auto dst_layout_nxc = is_dst_layout_nxc();
    const auto ch_blk = jcp.ch_block;
    const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
    const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;

    if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_->init_vcvtneps2bf16();

    if (dst_layout_nxc && jcp.dst_dt == data_type::bf16
            && isa_has_bf16(jcp.isa)) {
        // nxc + native bf16: adjacent channel blocks are contiguous in
        // memory, so two f32 accumulators can be packed into one bf16 zmm
        // and stored with a single write.
        for (int j = 0; j < ur_w; ++j) {
            int n_2bf2ps = (ur_ch_blocks / 2) * 2;
            int ch = 0;
            for (; ch < n_2bf2ps; ch += 2) {
                size_t aux_output_offset
                        = (size_t)ch * ocb_stride + j * ow_stride;
                auto addr = ptr[reg_output
                        + aux_output_offset * jcp.typesize_out];
                auto zmm_dst = get_acc_reg(ch * ur_w + j);
                // Pack two f32 accumulators into one bf16 vector.
                vcvtne2ps2bf16(
                        zmm_dst, get_acc_reg((ch + 1) * ur_w + j), zmm_dst);
                // The paired store can end on the tail block, hence the
                // 32-bit extended mask.
                bool mask_flag = last_ch_block_flag && ch + 2 == ur_ch_blocks;
                Zmm zmm_dst_msk = mask_flag ? zmm_dst | k_ch_tail_mask_extended
                                            : zmm_dst;
                vmovdqu16(addr, zmm_dst_msk);
            }
            /* Perform tail write for odd ch sizes */
            if (ch < ur_ch_blocks) {
                size_t aux_output_offset
                        = (size_t)ch * ocb_stride + j * ow_stride;
                auto addr = ptr[reg_output
                        + aux_output_offset * jcp.typesize_out];
                auto zmm_dst = get_acc_reg(ch * ur_w + j);
                auto ymm_dst = Ymm(zmm_dst.getIdx());
                vcvtneps2bf16(ymm_dst, zmm_dst);
                Ymm ymm_dst_msk
                        = last_ch_block_flag ? ymm_dst | ktail_mask : ymm_dst;
                vmovdqu16(addr, ymm_dst_msk);
            }
        }
    } else {
        // also used for case when dst_layout_nxc && dst.dt == f32
        if (jcp.dst_dt == data_type::f32) {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
                for (int ow = 0; ow < ur_w; ow++) {
                    int o_off = ch * ocb_stride + ow * ow_stride;
                    Zmm zmm_dst = get_acc_reg(ch * ur_w + ow);
                    Zmm zmm_dst_msk
                            = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
                    vmovups(vmmword[reg_output + o_off * jcp.typesize_out],
                            zmm_dst_msk);
                }
            }
        } else if (jcp.dst_dt == data_type::bf16) {
            if (isa_has_bf16(jcp.isa)) { // !dst_layout_nxc()
                // Blocked layout: adjacent ow points are contiguous, so pack
                // pairs of ow accumulators per store. Channels are padded in
                // this layout (see assert), so no masking is needed.
                assert(jcp.ngroups % jcp.ch_block == 0);
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    int n_2bf2ps = (ur_w / 2) * 2;
                    int j = 0;
                    for (; j < n_2bf2ps; j += 2) {
                        size_t aux_output_offset
                                = (size_t)ch * ocb_stride + j * ow_stride;
                        auto addr = ptr[reg_output
                                + aux_output_offset * jcp.typesize_out];
                        auto zmm_dst = get_acc_reg(ch * ur_w + j);
                        vcvtne2ps2bf16(zmm_dst, get_acc_reg(ch * ur_w + j + 1),
                                get_acc_reg(ch * ur_w + j));
                        vmovups(addr, zmm_dst);
                    }
                    /* Perform tail write for odd ur_w sizes */
                    if (j < ur_w) {
                        size_t aux_output_offset
                                = (size_t)ch * ocb_stride + j * ow_stride;
                        auto addr = ptr[reg_output
                                + aux_output_offset * jcp.typesize_out];
                        auto zmm_dst = get_acc_reg(ch * ur_w + j);
                        auto ymm_dst = Ymm(zmm_dst.getIdx());
                        vcvtneps2bf16(ymm_dst, zmm_dst);
                        vmovups(addr, ymm_dst);
                    }
                }
            } else {
                // bf16 destination without native bf16 support: convert each
                // accumulator through the emulator, store 16-bit lanes.
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    bool mask_flag
                            = last_ch_block_flag && ch == ur_ch_blocks - 1;
                    for (int ow = 0; ow < ur_w; ow++) {
                        int o_off = ch * ocb_stride + ow * ow_stride;
                        Zmm zmm_dst = get_acc_reg(ch * ur_w + ow);

                        /* down-convert f32 output to bf16 */
                        auto ymm_dst = Ymm(zmm_dst.getIdx());
                        bf16_emu_->vcvtneps2bf16(ymm_dst, zmm_dst);

                        Ymm ymm_dst_msk
                                = mask_flag ? ymm_dst | ktail_mask : ymm_dst;
                        vmovdqu16(ptr[reg_output + o_off * jcp.typesize_out],
                                ymm_dst_msk);
                    }
                }
            }
        } else
            assert(!"unsupported destination type" );
    }
}
373 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::compute_loop(
        int ur_w, int ur_ch_blocks, int pad_l, int pad_r) {
    // Compute one output-width tile; when ur_ch_blocks exceeds
    // jcp.nb_ch_blocking, a runtime loop walks the channel blocks in
    // nb_ch_blocking-sized groups.

    // ch_loop currently happen only when data layout is nxc. The strides are
    // calculated for this layout only.
    const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw
            * jcp.ch_block * jcp.typesize_in;
    const size_t inp_ch_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_in;
    const size_t out_ch_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_out;
    const size_t bias_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float);

    // Emit the full load -> filter -> post-ops -> store pipeline for a
    // group of channel blocks.
    auto compute = [&](int ur_ch_blocks, bool last_ch_block_flag = false) {
        if (jcp.is_fused_conv) {
            mov(aux_reg_input_buffer_ptr, reg_input_buffer_ptr);
        } else {
            mov(aux_reg_input, reg_input);
        }

        mov(aux_reg_kernel, reg_kernel);
        load_src(ur_ch_blocks, ur_w, last_ch_block_flag);
        apply_filter_unrolled(
                ur_ch_blocks, ur_w, pad_l, pad_r, last_ch_block_flag);
        apply_postops(ur_ch_blocks, ur_w, last_ch_block_flag);
        store_dst(ur_ch_blocks, ur_w, last_ch_block_flag);
    };

    const bool masked_ch_block_tail = jcp.oc % jcp.ch_block != 0;
    const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;

    push(reg_ch_blocks);

    if (ch_loop) {
        Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
        const int nb_ch = jcp.oc / jcp.ch_block;
        const int nb_ch_blocking_tail
                = jcp.nb_ch - utils::rnd_dn(nb_ch, jcp.nb_ch_blocking);
        const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;

        // The data pointers are advanced inside the loop; preserve the
        // originals across this tile.
        push(reg_kernel);
        push(reg_input);
        push(reg_output);
        if (jcp.with_bias) push(reg_bias);

        if (nb_ch >= jcp.nb_ch_blocking) {
            if (nb_ch_blocking_tail) {
                // Go straight to the tail when less than a full channel step
                // of work remains at run time.
                cmp(reg_ch_blocks, ch_step);
                jl(ch_tail_label, T_NEAR);
            }

            L(ch_loop_label);
            {
                compute(jcp.nb_ch_blocking);
                add(reg_kernel, wei_ch_stride);
                add(reg_input, inp_ch_stride);
                add(reg_output, out_ch_stride);
                if (jcp.with_bias) add(reg_bias, bias_stride);
                sub(reg_ch_blocks, ch_step);
                cmp(reg_ch_blocks, ch_step);
                jge(ch_loop_label, T_NEAR);
            }
        }
        if (nb_ch_blocking_tail) {
            // ch work range [1, jcp.nb_ch_blocking * ch_block)
            L(ch_tail_label);
            cmp(reg_ch_blocks, 0);
            jle(skip_ch_tail_label, T_NEAR);
            compute(nb_ch_blocking_tail, masked_ch_block_tail);
            L(skip_ch_tail_label);
        }
        if (jcp.with_bias) pop(reg_bias);
        pop(reg_output);
        pop(reg_input);
        pop(reg_kernel);

    } else {
        compute(ur_ch_blocks, masked_ch_block_tail);
    }

    pop(reg_ch_blocks);
}
457 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) {
    // Emit the output-width loop: an optional left-padded tile, a runtime
    // loop over full unpadded tiles, an optional right-padded tile, and the
    // ur_w_tail remainder — each compiled with exactly the padding it needs.

    int iw = jcp.iw;
    int ow = jcp.ow;
    int kw = jcp.kw;
    int l_pad = jcp.l_pad;
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int stride_w = jcp.stride_w;

    const auto src_layout_nxc = is_src_layout_nxc();
    const auto dat_c_stride = src_layout_nxc ? jcp.ngroups : jcp.ch_block;
    size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * dat_c_stride;
    size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride;

    // The first tile consumes l_pad fewer input columns than a regular one.
    int inp_shift_pad
            = jcp.typesize_in * (ur_w * stride_w - l_pad) * dat_c_stride;

    int r_pad = nstl::max(0, jcp.r_pad);
    int n_oi = ow / ur_w;
    // Right padding as seen by the last full-width tile.
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));

    assert(jcp.nb_ow <= 1);

    // A tile that touches right padding is peeled out of the main loop.
    if (r_pad1 > 0) n_oi--;
    xor_(reg_oi, reg_oi);
    if (ow == ur_w) {
        compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad);
    } else {
        if (n_oi == 0) {
            // A single full tile sees both left and right padding.
            compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad1);
            add(reg_input, inp_shift_pad);
            add(reg_output, out_shift);
            if (ur_w_tail != 0) {
                compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
            }
        } else {
            if (l_pad > 0) {
                compute_loop(ur_w, ur_ch_blocks, l_pad, 0);
                add(reg_input, inp_shift_pad);
                add(reg_output, out_shift);
                inc(reg_oi);
            }
            if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
                // Main runtime loop over interior tiles (no padding).
                Label ow_loop_label;
                L(ow_loop_label);
                {
                    compute_loop(ur_w, ur_ch_blocks, 0, 0);
                    add(reg_input, inp_shift);
                    add(reg_output, out_shift);

                    inc(reg_oi);
                    cmp(reg_oi, n_oi);
                    jl(ow_loop_label, T_NEAR);
                }
            }
            if (r_pad1 > 0) {
                compute_loop(ur_w, ur_ch_blocks, 0, r_pad1);
                add(reg_input, inp_shift);
                add(reg_output, out_shift);
            }
            if (ur_w_tail != 0) {
                compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
            }
        }
    }
}
526 | |
void jit_avx512_dw_conv_fwd_kernel_bf16::generate() {
    // Forward-kernel entry point: load call arguments, prepare tail masks,
    // and emit the channel / output-width loop nest.
    this->preamble();

    assert(mayiuse(avx512_core));
    if (jcp.is_fused_conv) {
        mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]);
        /* In case of fused depthwise convolution, `param.src` is not a pointer
        to input, instead it points to a buffer containing pointers to
        consecutive rows of input in format Cwc with blocking nb_ch_blocking.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
        Traverse the data as
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row0 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row1 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row2 ...
        */
        xor_(reg_iw_offset, reg_iw_offset);
    } else {
        mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    }
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]);

    Label ch_blocks_tail_label;
    Label exit_label;

    const int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
    const auto oc_tail = jcp.oc_without_padding % jcp.ch_block;
    if (oc_tail != 0) {
        // Note: is_src_layout_nxc() == true, otherwise channels are padded
        // Prepare masks for tailing
        const int oc_tail_shift
                = jcp.ch_block - jcp.oc_without_padding % jcp.ch_block;
        static constexpr auto zmm_16b_mask = ((1 << 16) - 1);

        // To account for special store optimization, where two oc_blocks are
        // combined with one single write, extend the mask for 32 bits
        // (i.e. 32 bfloat16 elements)
        const bool need_extended_mask = jcp.dst_dt == data_type::bf16
                && isa_has_bf16(jcp.isa) && jcp.nb_ch_blocking > 1;
        if (need_extended_mask)
            // Default the extended mask to all-ones (no masking); it is
            // narrowed below only when this call processes the tail.
            kxnord(k_ch_tail_mask_extended, k_ch_tail_mask_extended,
                    k_ch_tail_mask_extended);

        // Set the real tail masks only when load_work is smaller than a
        // full channel step, i.e. the tail is actually being processed.
        Label done;
        mov(reg_tail, ptr[param1 + GET_OFF(load_work)]);
        cmp(reg_tail, jcp.nb_ch_blocking * jcp.ch_block);
        je(done, T_NEAR);
        Reg32 reg_tail_32 = reg_tail.cvt32();
        mov(reg_tail_32, zmm_16b_mask >> oc_tail_shift);
        kmovw(k_oc_tail_mask, reg_tail_32);
        if (need_extended_mask) {
            // Full first block plus oc_tail elements of the second block.
            auto zmm_32b_mask = (1 << (oc_tail + jcp.ch_block)) - 1;
            mov(reg_tail_32, zmm_32b_mask);
            kmovd(k_ch_tail_mask_extended, reg_tail_32);
        }
        L(done);
    }

    if (is_src_layout_nxc()) {
        loop_ow(jcp.nb_ch);
    } else {
        // Blocked layout: dispatch at run time between the main channel
        // blocking and the channel-tail blocking based on load_work.
        cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block);
        jle(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);

        loop_ow(jcp.nb_ch_blocking); // channel main loop

        if (ch_blocks_tail) {
            jmp(exit_label, T_NEAR);
            L(ch_blocks_tail_label);

            loop_ow(ch_blocks_tail); // channel tail loop
        }

        L(exit_label);
    }

    postamble();

    // Eltwise lookup-table data is emitted after the code section.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
615 | |
616 | inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::load_ddst( |
617 | int ur_ch_blocks, int ur_str_w) { |
618 | for (int ch = 0; ch < ur_ch_blocks; ch++) { |
619 | for (int w = 0; w < ur_str_w; w++) { |
620 | Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w); |
621 | uni_vpxor(zmm_acc, zmm_acc, zmm_acc); |
622 | } |
623 | } |
624 | } |
625 | |
inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_filter(
        int ur_ch_blocks, int ur_str_w, bool last_ch_block_flag) {
    // Accumulate backward-data as a strided pass over diff_dst: the kh/kw
    // loops are runtime loops stepping by stride_h/stride_w, while channel
    // blocks and the ur_str_w width points are unrolled.
    int kw = jcp.kw;
    int kh = jcp.kh;
    int ow = jcp.ow;
    int oh = jcp.oh;

    int ch_blk = jcp.ch_block;
    int stride_h = jcp.stride_h;
    int stride_w = jcp.stride_w;

    const bool ddst_layout_nxc = is_ddst_layout_nxc();
    const size_t ch_block_step = ch_blk * (ddst_layout_nxc ? 1 : oh * ow);
    const size_t sp_step = ddst_layout_nxc ? jcp.ngroups : ch_blk;

    Label iter_exit_label;

    // Nothing to do when padding leaves no overlapping kernel rows/columns.
    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    cmp(reg_kw, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label);
    {
        mov(aux1_reg_ddst, aux_reg_ddst);
        mov(aux1_reg_kernel, aux_reg_kernel);

        mov(iter_kw, reg_kw);
        Label kw_label;
        L(kw_label);
        {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                // Only the last channel block may be partial, hence masked.
                const bool mask_flag
                        = last_ch_block_flag && ch == ur_ch_blocks - 1;
                int ker_off = ch * kh * kw * ch_blk;
                Zmm mm_zmm_ker // mm: maybe masked
                        = mask_flag ? zmm_ker_reg | k_ch_tail_mask | T_z
                                    : zmm_ker_reg;
                // Load one bf16 weight vector, zero-extended to 32 bits.
                vpmovzxwd(mm_zmm_ker,
                        ptr[aux1_reg_kernel + ker_off * jcp.typesize_in]);

                for (int w = 0; w < ur_str_w; w++) {
                    size_t sp_offset = w * sp_step;
                    size_t ch_offset = ch * ch_block_step;
                    size_t ddst_off = sp_offset + ch_offset;
                    Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w);
                    Zmm mm_zmm_dst // mm: maybe masked
                            = mask_flag ? zmm_dst_reg | k_ch_tail_mask | T_z
                                        : zmm_dst_reg;
                    vpmovzxwd(mm_zmm_dst,
                            ptr[aux1_reg_ddst + ddst_off * jcp.typesize_in]);

                    // NOTE(review): the emulated call passes its two source
                    // operands in the opposite order to the native one; the
                    // dot-product accumulate is symmetric in its sources, so
                    // the result is the same.
                    if (isa_has_bf16(jcp.isa))
                        vdpbf16ps(zmm_acc, mm_zmm_ker, mm_zmm_dst);
                    else
                        bf16_emu_->vdpbf16ps(zmm_acc, mm_zmm_dst, mm_zmm_ker);
                }
            }

            // Next kw tap: weights advance forward while diff_dst steps
            // backward in the spatial dimension.
            add(aux1_reg_kernel, ch_blk * stride_w * jcp.typesize_in);
            sub(aux1_reg_ddst, sp_step * jcp.typesize_in);

            sub(iter_kw, stride_w);
            cmp(iter_kw, 0);
            jg(kw_label, T_NEAR);
        }

        // Next kh tap: same forward/backward pattern along the height.
        add(aux_reg_kernel, kw * ch_blk * stride_h * jcp.typesize_in);
        sub(aux_reg_ddst, ow * sp_step * jcp.typesize_in);

        sub(iter_kh, stride_h);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
706 | |
inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::store_dsrc(
        int ur_ch_blocks, int ur_str_w, bool last_ch_block_flag) {
    // Store the accumulated diff_src tile, down-converting f32 -> bf16 when
    // the diff_src data type requires it.
    int ch_blk = jcp.ch_block;
    int iw = jcp.iw;
    int ih = jcp.ih;
    int stride_w = jcp.stride_w;

    const auto dsrc_layout_nxc = is_dsrc_layout_nxc();
    const size_t ch_block_step = ch_blk * (dsrc_layout_nxc ? 1 : ih * iw);
    const size_t sp_step = dsrc_layout_nxc ? jcp.ngroups : ch_blk;

    if (jcp.dsrc_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_->init_vcvtneps2bf16();

    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        // Only the last channel block may be partial, hence masked.
        const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
        for (int w = 0; w < ur_str_w; w++) {
            // Consecutive output points land stride_w apart in diff_src.
            size_t sp_offset = w * stride_w * sp_step;
            size_t ch_offset = ch * ch_block_step;
            int dsrc_off = sp_offset + ch_offset;
            auto zmm_dsrc = get_acc_reg(ch * ur_str_w + w);
            Zmm mm_zmm_dsrc // mm: maybe masked
                    = mask_flag ? zmm_dsrc | k_ch_tail_mask : zmm_dsrc;

            if (jcp.dsrc_dt == data_type::f32) {
                uni_vmovups(ptr[reg_dsrc + dsrc_off * jcp.typesize_out],
                        mm_zmm_dsrc);
            } else if (jcp.dsrc_dt == data_type::bf16) {
                auto ymm_dsrc = Ymm(zmm_dsrc.getIdx());
                Ymm mm_ymm_dsrc // mm: maybe masked
                        = mask_flag ? ymm_dsrc | k_ch_tail_mask : ymm_dsrc;

                if (isa_has_bf16(jcp.isa))
                    vcvtneps2bf16(mm_ymm_dsrc, mm_zmm_dsrc);
                else
                    bf16_emu_->vcvtneps2bf16(mm_ymm_dsrc, mm_zmm_dsrc);
                vmovdqu16(ptr[reg_dsrc + dsrc_off * jcp.typesize_out],
                        mm_ymm_dsrc);
            }
        }
    }
    /* Note: current 'store_dsrc' is limited to storing 'ymm' output. This is
     * because of the current implementation approach that calculates convolution as
     * a strided backward-pass. To increase store throughput by writing 'zmm'
     * registers, changes are needed in both JIT-kernel and Driver code. */
}
753 | |
inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::ch_loop_body(
        int ur_ch_blocks, int unroll_w) {
    // Process all channel blocks for one width tile: either in a single
    // pass, or via a runtime loop over nb_ch_blocking-sized groups (the
    // loop is only emitted for nxc layouts, see the assert below).

    // Emit one load -> filter -> store pass for a group of channel blocks.
    auto call_compute_body
            = [&](int ur_ch_blocks, int unroll_w, bool is_last_ch = false) {
                  mov(aux_reg_ddst, reg_ddst);
                  mov(aux_reg_kernel, reg_kernel);

                  load_ddst(ur_ch_blocks, unroll_w);
                  apply_filter(ur_ch_blocks, unroll_w, is_last_ch);
                  store_dsrc(ur_ch_blocks, unroll_w, is_last_ch);
              };

    const bool write_ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;
    if (write_ch_loop) {
        assert(is_ddst_layout_nxc() && is_dsrc_layout_nxc());

        Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
        const int nb_oc = jcp.oc / jcp.ch_block;
        const int ch_block_tail
                = jcp.nb_ch - (utils::rnd_dn(nb_oc, jcp.nb_ch_blocking));
        const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;

        const size_t wei_ch_stride
                = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw * jcp.ch_block;
        const size_t data_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.ch_block;

        // Data pointers are advanced inside the loop; keep the originals.
        mov(aux_reg_ch_blocks, reg_ch_blocks);
        push(reg_dsrc);
        push(reg_ddst);
        push(reg_kernel);

        if (nb_oc >= jcp.nb_ch_blocking) {
            if (ch_block_tail) {
                // Go straight to the tail when less than a full channel
                // step of work remains at run time.
                cmp(aux_reg_ch_blocks, jcp.nb_ch_blocking * jcp.ch_block);
                jl(ch_tail_label, T_NEAR);
            }

            L(ch_loop_label);
            {
                call_compute_body(jcp.nb_ch_blocking, unroll_w);

                add(reg_kernel, wei_ch_stride * jcp.typesize_in);
                add(reg_dsrc, data_ch_stride * jcp.typesize_out);
                add(reg_ddst, data_ch_stride * jcp.typesize_in);

                sub(aux_reg_ch_blocks, ch_step);
                cmp(aux_reg_ch_blocks, ch_step);
                jge(ch_loop_label, T_NEAR);
            }
        }

        if (ch_block_tail) {
            // ch work range [1, jcp.nb_ch_blocking * ch_block)
            L(ch_tail_label);
            cmp(aux_reg_ch_blocks, 0);
            jle(skip_ch_tail_label, T_NEAR);
            call_compute_body(ch_block_tail, unroll_w, jcp.ch_tail);
            L(skip_ch_tail_label);
        }

        pop(reg_kernel);
        pop(reg_ddst);
        pop(reg_dsrc);

    } else {
        call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail);
    }
}
823 | |
inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::unroll_width_body(
        int ur_ch_blocks) {
    // Consume reg_ur_str_w width points at run time: first in jcp.ur_w-sized
    // chunks, then one point at a time for the remainder.

    auto unroll_width_loop = [&](int unroll_w) {
        Label unroll_w_label, skip_compute_label;
        L(unroll_w_label);
        {
            const size_t ch_step = unroll_w
                    * (is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block);
            // Exit once fewer than unroll_w points remain.
            cmp(reg_ur_str_w, unroll_w);
            jl(skip_compute_label, T_NEAR);

            ch_loop_body(ur_ch_blocks, unroll_w);

            // Advance the data pointers past the processed width chunk.
            add(reg_dsrc, jcp.typesize_out * jcp.stride_w * ch_step);
            add(reg_ddst, jcp.typesize_in * ch_step);

            sub(reg_ur_str_w, unroll_w);
            jmp(unroll_w_label);
        }
        L(skip_compute_label);
    };

    unroll_width_loop(jcp.ur_w);

    unroll_width_loop(1);
}
851 | |
void jit_avx512_dw_conv_bwd_data_kernel_bf16::generate() {
    // Backward-data kernel entry point: load call arguments, set up the
    // channel tail mask (nxc only), and emit the width/channel loop nest.
    assert(is_dsrc_layout_nxc() == is_ddst_layout_nxc());

    preamble();
    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
    mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);

    if (is_dsrc_layout_nxc()) {
        if (jcp.ch_tail) {
            // Use an all-ones mask unless this call processes the channel
            // tail (ch_blocks smaller than a full step).
            Label masking_done;
            const size_t channel_step = jcp.nb_ch_blocking * jcp.ch_block;
            kxnorw(k_ch_tail_mask, k_ch_tail_mask,
                    k_ch_tail_mask); // dummy mask all 1's
            cmp(reg_ch_blocks, channel_step);
            je(masking_done, T_NEAR);
            // Prepare masks for tail
            Reg32 reg_tmp_32 = reg_tmp.cvt32();
            mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
            kmovw(k_ch_tail_mask, reg_tmp_32);
            L(masking_done);
        }

        unroll_width_body(jcp.nb_ch);
    } else {
        // Blocked layout: run the widest channel blocking that fits the
        // runtime channel count, then the tail blocking.
        auto ch_blocks_loop = [&](int ch_blocks) {
            Label skip_loop_label;
            cmp(reg_ch_blocks, ch_blocks * jcp.ch_block);
            jl(skip_loop_label, T_NEAR);
            unroll_width_body(ch_blocks);
            L(skip_loop_label);
        };

        ch_blocks_loop(jcp.nb_ch_blocking);

        int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
        if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); }
    }
    postamble();
}
896 | #undef GET_OFF |
897 | |
898 | #define GET_OFF(field) offsetof(jit_dw_conv_call_s, field) |
899 | void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_filter() { |
900 | for (int i = 0; i < jcp.kw; ++i) { |
901 | Zmm zmm_acc = get_acc_reg(i); |
902 | uni_vpxor(zmm_acc, zmm_acc, zmm_acc); |
903 | } |
904 | } |
905 | |
906 | void jit_avx512_dw_conv_bwd_weights_kernel_bf16::load_filter(bool is_last_ch) { |
907 | for (int i = 0; i < jcp.kw; ++i) { |
908 | int off_filter = i * jcp.ch_block; |
909 | Zmm zmm_acc = get_acc_reg(i); |
910 | Zmm m_zmm_acc = is_last_ch ? zmm_acc | k_ch_tail_mask | T_z : zmm_acc; |
911 | vmovups(m_zmm_acc, |
912 | vmmword[reg_tmp_filter + off_filter * jcp.typesize_out]); |
913 | } |
914 | } |
915 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_bias() {
    // Clear the f32 bias accumulator register.
    uni_vpxor(zmm_bias_reg, zmm_bias_reg, zmm_bias_reg);
}
919 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::load_bias(bool is_last_ch) {
    // Load the f32 bias accumulator from memory; on the last (partial)
    // channel block, zero-mask the lanes beyond the channel tail so stale
    // lane contents cannot leak into the accumulation.
    Zmm m_zmm_bias_reg
            = is_last_ch ? zmm_bias_reg | k_ch_tail_mask | T_z : zmm_bias_reg;
    vmovups(m_zmm_bias_reg, vmmword[reg_bias_baddr]);
}
925 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
        int unroll_w, int l_pad, int pad_offset, int ow_block,
        bool is_last_ch) {
    // Core of the weights-gradient computation for one unrolled span of
    // output width: for each output element, load the diff_dst value, keep a
    // rolling window of input values in registers ("cascade"), and
    // accumulate input x diff_dst products into the per-kw accumulators.
    // bf16 inputs are widened to dword lanes via vpmovzxwd before the
    // vdpbf16ps dot-product.

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const int iw_block = ow_block * jcp.stride_w;
    const int right_border = jcp.iw - iw_block;
    const int r_pad = jcp.r_pad;

    // Number of fresh input values each subsequent output element needs.
    const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);

    /* preamble count for number of cascaded LOAD + FMA operation */
    const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
    const bool is_last_block = (unroll_w + ow_block == jcp.ow);

    /* LOAD initial input registers, then cascade LOADs and FMAs*/
    for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
        size_t off_output
                = static_cast<size_t>(i_ur * ch_step * jcp.typesize_in);
        Zmm m_zmm_out_reg
                = is_last_ch ? zmm_out_reg | k_ch_tail_mask | T_z : zmm_out_reg;
        vpmovzxwd(m_zmm_out_reg, ptr[reg_tmp_output + off_output]);
        if (i_ur == 0) {
            // Prime the register window with the inputs overlapping the
            // first output element, skipping loads that would fall into
            // left padding or past the right border.
            for (int c = 0; c < input_overlap; ++c) {
                int input_sp = c - pad_offset;
                if (input_sp < 0 && unroll_w == jcp.ow) continue;

                const bool over_steps_bdry = true && is_last_block
                        && (c - pad_offset + r_pad > right_border);
                if (over_steps_bdry) continue;

                size_t input_offset = static_cast<size_t>(
                        input_sp * ch_step * jcp.typesize_in);
                Zmm zmm_input = get_input_reg(c);
                Zmm m_zmm_input = is_last_ch ? zmm_input | k_ch_tail_mask | T_z
                                             : zmm_input;
                vpmovzxwd(m_zmm_input, ptr[reg_tmp_input + input_offset]);
            }
        } else {
            // Cascade: load only the new inputs exposed by advancing the
            // window by stride_w.
            for (int c = 0; c < cascade_input; ++c) {
                int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
                int input_sp = overlap + c - pad_offset;
                if (input_sp < 0 || overlap + c + l_pad > right_border)
                    continue;

                const bool over_steps_bdry = true && is_last_block
                        && (overlap + c - pad_offset + r_pad > right_border);
                if (over_steps_bdry) continue;

                size_t input_offset = static_cast<size_t>(
                        input_sp * ch_step * jcp.typesize_in);
                Zmm zmm_input = get_input_reg(overlap + c);
                Zmm m_zmm_input = is_last_ch ? zmm_input | k_ch_tail_mask | T_z
                                             : zmm_input;
                vpmovzxwd(m_zmm_input, ptr[reg_tmp_input + input_offset]);
            }
        }

        for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
            int io_overlap = i_kw + (i_ur * jcp.stride_w);

            /* Don't apply FMAs that fall into the padded region */
            // NOTE(review): the two guards mix the local 'l_pad' parameter
            // and 'jcp.l_pad' — presumably intentional (local l_pad is
            // zeroed after the left-padded block) but worth confirming.
            if (io_overlap - l_pad < 0
                    || io_overlap - jcp.l_pad >= right_border)
                continue;

            const bool over_steps_bdry = true && is_last_block
                    && (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
            if (over_steps_bdry) continue;

            Zmm zmm_input = get_input_reg(io_overlap - l_pad);
            Zmm zmm_acc = get_acc_reg(i_kw);
            // Use the native bf16 dot-product when the ISA has it, else the
            // software emulation sequence.
            if (isa_has_bf16(jcp.isa))
                vdpbf16ps(zmm_acc, zmm_input, zmm_out_reg);
            else
                bf16_emu_->vdpbf16ps(zmm_acc, zmm_input, zmm_out_reg);
        }
    }
}
1005 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_bias_step_unroll(
        const int unroll_w, bool is_last_ch) {
    // Accumulate 'unroll_w' consecutive diff_dst elements into the f32 bias
    // accumulator, masking the channel tail on the last channel block.

    const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    for (int i = 0; i < unroll_w; ++i) {
        size_t off_output = static_cast<size_t>(i * ch_step * jcp.typesize_in);
        /* bf16 output data requires conversion to f32 */
        Zmm m_zmm_out_reg
                = is_last_ch ? zmm_out_reg | k_ch_tail_mask | T_z : zmm_out_reg;
        // Widen bf16 words to dword lanes, then shift left by 16 to place
        // the bf16 bits in the high half of each f32 lane (bf16 -> f32).
        vpmovzxwd(m_zmm_out_reg, ptr[reg_tmp_output + off_output]);
        vpslld(m_zmm_out_reg, m_zmm_out_reg, 0x10);
        vaddps(zmm_bias_reg, zmm_bias_reg, m_zmm_out_reg);
    }
}
1020 | |
1021 | void jit_avx512_dw_conv_bwd_weights_kernel_bf16::store_filter(bool is_last_ch) { |
1022 | |
1023 | /* bf16: all data is stored as f32. Down-convert to bf16 happens at the |
1024 | * reduction phase. */ |
1025 | for (int i = 0; i < jcp.kw; ++i) { |
1026 | int off_filter = i * jcp.ch_block; |
1027 | Zmm zmm_acc = get_acc_reg(i); |
1028 | Zmm m_zmm_acc = is_last_ch ? zmm_acc | k_ch_tail_mask : zmm_acc; |
1029 | vmovups(vmmword[reg_tmp_filter + off_filter * jcp.typesize_out], |
1030 | m_zmm_acc); |
1031 | } |
1032 | } |
1033 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::store_bias(bool is_last_ch) {
    // Store the f32 bias accumulator; the tail store is merge-masked (no
    // T_z) so memory beyond the last channel is left untouched.
    Zmm m_zmm_bias_reg
            = is_last_ch ? zmm_bias_reg | k_ch_tail_mask : zmm_bias_reg;
    vmovups(vmmword[reg_bias_baddr], m_zmm_bias_reg);
}
1039 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_spatial_loop_bias(
        bool is_last_ch) {
    // Accumulate the bias gradient over the whole output spatial domain:
    // an outer loop over output rows [oh_index, oh_count) and an inner loop
    // over width, unrolled in chunks of 'unroll_w' plus a tail.
    Label oh_label;
    Label ow_blk_label;

    const int unroll_w = nstl::min(max_unroll_w_, jcp.ow);
    const int unroll_w_trips = jcp.ow / unroll_w;
    const int tail_w = jcp.ow > max_unroll_w_ ? jcp.ow % max_unroll_w_ : 0;

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t ch_offset = ch_step * jcp.typesize_in;

    mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
    mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);

    mov(reg_tmp_output, reg_output_baddr);
    L(oh_label);
    {

        mov(reg_iter_ow_blk, unroll_w_trips);
        L(ow_blk_label);
        {
            compute_bias_step_unroll(unroll_w, is_last_ch);
            add(reg_tmp_output, unroll_w * ch_offset);

            dec(reg_iter_ow_blk);
            cmp(reg_iter_ow_blk, 0);
            jg(ow_blk_label, T_NEAR);
        }

        // Remaining output-width elements that did not fill a full chunk.
        if (tail_w > 0) {
            compute_bias_step_unroll(tail_w, is_last_ch);
            add(reg_tmp_output, tail_w * ch_offset);
        }

        inc(reg_oh);
        cmp(reg_oh, reg_oh_worksize);
        jl(oh_label, T_NEAR);
    }
}
1080 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
        compute_single_ch_block_bias() {
    // Bias gradient for the blocked (non-nxc) layout, which processes one
    // channel block per invocation. When a channel tail exists, a runtime
    // check on FLAG_OC_LAST selects between the masked and unmasked
    // variants of the same load/compute/store sequence.

    auto write_compute_bias = [&](bool masked_ch_tail) {
        Label skip_load_bias;

        // FLAG_ZERO_BIAS set => first reduction step: start from the
        // zeroed accumulator instead of loading a partial result.
        mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
        and_(reg_exec_flags, FLAG_ZERO_BIAS);
        test(reg_exec_flags, reg_exec_flags);
        jne(skip_load_bias);

        load_bias(masked_ch_tail);

        L(skip_load_bias);
        compute_spatial_loop_bias(masked_ch_tail);

        store_bias(masked_ch_tail);
    };

    Label skip_masked_bias_label, done_bias_label;

    zero_bias();

    bool do_bias_ch_tail = jcp.ch_tail > 0;
    if (do_bias_ch_tail) {
        // test last channel
        mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
        and_(reg_exec_flags, FLAG_OC_LAST);
        test(reg_exec_flags, reg_exec_flags);
        jz(skip_masked_bias_label, T_NEAR);

        write_compute_bias(true);

        jmp(done_bias_label, T_NEAR);
        L(skip_masked_bias_label);
    }

    write_compute_bias(false);

    L(done_bias_label);
}
1122 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ch_loop_bias(
        bool do_load_bias) {
    // Bias gradient for the nxc layout. Emits an unmasked body and, when a
    // channel tail exists and there is more than one channel block, a
    // masked body selected at runtime by FLAG_OC_LAST.

    assert(is_ddst_layout_nxc());

    auto write_compute_bias = [&](bool masked_ch_tail) {
        // Either continue a partial accumulation (load) or start fresh
        // (zero), as decided by the caller.
        if (do_load_bias)
            load_bias(masked_ch_tail);
        else
            zero_bias();
        compute_spatial_loop_bias(masked_ch_tail);
        store_bias(masked_ch_tail);
    };

    bool masked_ch_tail = jcp.ch_tail > 0;
    if (jcp.nb_ch > 1) {

        Label last_ch_block_label, ch_block_done_label;
        if (masked_ch_tail) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        write_compute_bias(false);

        if (masked_ch_tail) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            write_compute_bias(true);

            L(ch_block_done_label);
        }
    } else {
        // Single channel block: emit only the variant that applies.
        write_compute_bias(masked_ch_tail);
    }
}
1162 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::deploy_ch_loop_bias() {
    // Runtime dispatch on FLAG_ZERO_BIAS: the first reduction step zeroes
    // the accumulator, subsequent steps load the partial bias and continue
    // accumulating. Both variants are emitted; the flag selects one.

    Label ch_loop_label, zero_bias_label, load_bias_done_label;

    mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_BIAS);
    test(reg_exec_flags, reg_exec_flags);
    jne(zero_bias_label, T_NEAR);

    compute_ch_loop_bias(true); // load_bias
    jmp(load_bias_done_label, T_NEAR);

    L(zero_bias_label);
    compute_ch_loop_bias(false); // zero_bias

    L(load_bias_done_label);
}
1180 | |
1181 | void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_bias() { |
1182 | |
1183 | mov(reg_bias_baddr, ptr[this->param1 + GET_OFF(bias)]); |
1184 | |
1185 | if (is_ddst_layout_nxc()) |
1186 | deploy_ch_loop_bias(); |
1187 | else |
1188 | compute_single_ch_block_bias(); |
1189 | } |
1190 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_filter_kh_loop() {
    // Write the (already zeroed) accumulators to every kh row of the
    // scratch filter buffer, then restore reg_tmp_filter to its starting
    // position.

    const size_t filter_offset_kw = jcp.kw * jcp.ch_block * jcp.typesize_out;
    const size_t filter_offset_kh = jcp.kh * filter_offset_kw;

    Label kh_loop_label;

    mov(reg_kh_aux, jcp.kh);
    L(kh_loop_label);
    {
        store_filter();

        add(reg_tmp_filter, filter_offset_kw);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Comeback pointers */
    sub(reg_tmp_filter, filter_offset_kh);
}
1212 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::deploy_zero_filter() {
    // Zero-initialize the scratch filter buffer, but only when the driver
    // requested it via FLAG_ZERO_FILTER (i.e. the first accumulation step).

    Label skip_zeroing_label;

    mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_FILTER);
    test(reg_exec_flags, reg_exec_flags);
    je(skip_zeroing_label, T_NEAR);

    zero_filter();

    mov(reg_tmp_filter, reg_filter_baddr);
    zero_filter_kh_loop();

    L(skip_zeroing_label);
}
1229 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_kh_step(int unroll_w,
        int l_pad, int pad_offset, int ow_block, bool is_last_ch) {
    // For each kernel-height row (count in reg_kh, resolved at runtime):
    // load the partial filter accumulators, accumulate one unrolled span of
    // output width, and store them back; then rewind the input and filter
    // pointers to where they started.

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t input_offset = jcp.iw * ch_step * jcp.typesize_in;
    const size_t filter_offset = jcp.kw * jcp.ch_block * jcp.typesize_out;

    Label kh_loop_label, skip_loop_label;

    // reg_kh may be zero when the row lies entirely in padding.
    cmp(reg_kh, 0);
    je(skip_loop_label, T_NEAR);

    mov(reg_kh_aux, reg_kh);
    L(kh_loop_label);
    {
        load_filter();
        compute_ow_step_unroll(
                unroll_w, l_pad, pad_offset, ow_block, is_last_ch);
        store_filter();

        add(reg_tmp_filter, filter_offset);
        add(reg_tmp_input, input_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Comeback pointers */
    // Undo the per-row advances; done as a loop because the trip count is
    // only known at runtime (reg_kh).
    Label kh_comeback_label;
    mov(reg_kh_aux, reg_kh);
    L(kh_comeback_label);
    {
        sub(reg_tmp_input, input_offset);
        sub(reg_tmp_filter, filter_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_comeback_label, T_NEAR);
    }

    L(skip_loop_label);
}
1271 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ch_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {
    // Wrap compute_kh_step with channel-tail handling: for nxc with several
    // channel blocks, emit both an unmasked and a masked body and pick one
    // at runtime via FLAG_OC_LAST; otherwise emit a single body.

    const bool masked_ch_tail = is_layout_nxc() && jcp.ch_tail > 0;
    bool write_channel_loop = is_layout_nxc() && jcp.nb_ch > 1;
    if (write_channel_loop) {
        Label last_ch_block_label, ch_block_done_label;
        if (masked_ch_tail) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, false);

        if (masked_ch_tail) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, true);
            L(ch_block_done_label);
        }
    } else {
        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, masked_ch_tail);
    }
}
1299 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_h_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {
    // Loop over output rows [oh_index, oh_count), calling compute_ch_loop
    // per row and adjusting the effective kernel-height count (reg_kh) and
    // the input/filter pointers when the row overlaps top or bottom
    // padding.

    mov(reg_tmp_output, reg_output_baddr);
    mov(reg_tmp_input, reg_input_baddr);
    mov(reg_tmp_filter, reg_filter_baddr);

    // First output row whose kernel window reaches into bottom padding.
    const int input_bottom_padding_overlap
            = div_up(jcp.ih + jcp.t_pad - (jcp.kh - 1), jcp.stride_h);

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t input_shift = jcp.typesize_in * jcp.iw * ch_step;
    const size_t output_shift = jcp.typesize_in * jcp.ow * ch_step;
    const size_t filter_shift = jcp.typesize_out * jcp.kw * jcp.ch_block;

    Label loop_begin_label, loop_end_label, common_block_label,
            top_padding_end_label, bottom_padding_end_label,
            bottom_padding_label;

    mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_count)]);

    // replacement for 'os_index_end'
    mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);

    cmp(reg_kh, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kh
    cmp(reg_oh, reg_oh_worksize);
    jge(loop_end_label, T_NEAR); // no iterations along height dimension

    L(loop_begin_label);

    compute_ch_loop(unroll_w, l_pad, pad_offset, ow_block);

    /* Compute 'top' edge */
    if (jcp.t_pad > 0) {

        /* Check if within top padding region */
        cmp(reg_oh, div_up(jcp.t_pad, jcp.stride_h));
        jge(top_padding_end_label, T_NEAR);

        /* Increment step counter and adjust filter position */
        // Moving down one output row exposes stride_h more kernel rows at
        // the top, so the filter pointer moves back and reg_kh grows.
        sub(reg_tmp_filter, filter_shift * jcp.stride_h);
        add(reg_kh, jcp.stride_h);

        /* Final number of kernel elements that overlap with input */
        const int inp_ker_overlap = nstl::min(jcp.kh, jcp.ih);
        cmp(reg_kh, inp_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and input */
        if (jcp.t_pad <= jcp.oh * jcp.stride_h) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.t_pad % jcp.stride_h != 0) {
                int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
                add(reg_tmp_filter, filter_shift * inp_corr);
                add(reg_tmp_input, input_shift * inp_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_tmp_filter,
                    (jcp.t_pad - jcp.oh * jcp.stride_h) * filter_shift);
        }

        /* Apply correction: reset value of 'reg_kh' to scenario outside of
         * special cases due to top_padding (i.e. 'min(jcp.kh, jcp.ih)')*/
        mov(reg_kh, inp_ker_overlap);
        jmp(common_block_label);

        L(top_padding_end_label);
    }

    /* Compute 'bottom' edge */
    if (jcp.b_pad > 0) {

        /* Check if within bottom padding region */
        // Three-way branch: before the overlap row (jl), exactly at it
        // (fall-through), or past it (jg).
        cmp(reg_oh, input_bottom_padding_overlap - 1);
        jl(bottom_padding_end_label, T_NEAR);
        jg(bottom_padding_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * bottom padding region. */
        mov(reg_kh,
                jcp.ih + jcp.t_pad
                        - input_bottom_padding_overlap * jcp.stride_h);
        jmp(bottom_padding_end_label, T_NEAR);

        L(bottom_padding_label);
        // Deeper into bottom padding: stride_h fewer kernel rows overlap
        // the input each step; exit once none remain.
        sub(reg_kh, jcp.stride_h);
        cmp(reg_kh, 0);
        jle(loop_end_label, T_NEAR);

        L(bottom_padding_end_label);
    }

    /* Compute middle block */
    add(reg_tmp_input, input_shift * jcp.stride_h);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_tmp_output, output_shift);
    inc(reg_oh);
    cmp(reg_oh, reg_oh_worksize);
    jl(loop_begin_label, T_NEAR);

    L(loop_end_label);
}
1407 | |
1408 | void jit_avx512_dw_conv_bwd_weights_kernel_bf16::calculate_w_unrolling( |
1409 | int &unroll_trips, int &unroll_w, int &unroll_w_tail) { |
1410 | |
1411 | const bool do_unroll_w = jcp.ow > max_unroll_w_; |
1412 | if (do_unroll_w) { |
1413 | unroll_w = nstl::min(block_size_, jcp.ow); |
1414 | unroll_trips = jcp.ow / unroll_w; |
1415 | /* calculate tail */ |
1416 | unroll_w_tail = jcp.ow % unroll_w; |
1417 | /* Perform some rebalancing if tail too small*/ |
1418 | if ((unroll_w_tail == 0 && jcp.r_pad != 0) |
1419 | || (jcp.r_pad > 0 && jcp.r_pad >= unroll_w_tail)) { |
1420 | if (unroll_trips > 1) { |
1421 | unroll_w_tail += unroll_w; |
1422 | unroll_trips--; |
1423 | } else { |
1424 | /* Idealy, this case shouldn't happen */ |
1425 | unroll_w_tail += (unroll_w - unroll_w / 2); |
1426 | unroll_w = unroll_w / 2; |
1427 | } |
1428 | } |
1429 | } else { |
1430 | unroll_w_tail = jcp.ow; |
1431 | } |
1432 | } |
1433 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_block_unroll() {
    // Top-level body of the weights-gradient kernel: optional bias
    // computation and filter zeroing, then the output width processed as
    // [left-padded block] [middle blocks] [right-padded tail block].

    Label ow_blk_label; // for compute middle block
    int pad_offset = 0;
    int l_pad = jcp.l_pad;
    int unroll_w_tail = 0;
    int unroll_w = 0;
    int unroll_trips = 0;
    calculate_w_unrolling(unroll_trips, unroll_w, unroll_w_tail);

    const size_t ch_offset = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t data_offset
            = static_cast<size_t>(unroll_w * ch_offset * jcp.typesize_in);

    if (jcp.with_bias) compute_bias();

    /* Pass filter address, then offset for h_padding. */
    deploy_zero_filter();
    mov(reg_kh_offset, ptr[this->param1 + GET_OFF(filter_pad_off)]);
    add(reg_filter_baddr, reg_kh_offset);

    /* compute left padded block */
    const bool do_unroll_w = jcp.ow > max_unroll_w_;
    if (l_pad && do_unroll_w) {
        compute_h_loop(unroll_w, l_pad, 0, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
        unroll_trips--;
        // Left padding is fully consumed here; remaining blocks only need
        // the accumulated pad_offset to index the input correctly.
        pad_offset = l_pad;
        l_pad = 0;
    }

    /* Insert loop for 'ow' block when middle block needs to execute more
     * than once */
    bool do_ow_blk_loop = unroll_trips > 1;
    if (do_ow_blk_loop) {
        mov(reg_iter_ow_blk, unroll_trips);
        L(ow_blk_label);
    }
    if (unroll_trips > 0) {
        compute_h_loop(unroll_w, l_pad, pad_offset, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
    }
    if (do_ow_blk_loop) {
        dec(reg_iter_ow_blk);
        cmp(reg_iter_ow_blk, 0);
        jg(ow_blk_label, T_NEAR);
    }

    /* compute right padded block */
    if (unroll_w_tail) {
        compute_h_loop(
                unroll_w_tail, l_pad, pad_offset, jcp.ow - unroll_w_tail);
    }
}
1490 | |
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::generate() {
    // Kernel entry point (backward-by-weights): load the base pointers from
    // jit_dw_conv_call_s, set up the channel-tail opmask when a tail exists,
    // and emit the whole computation.
    assert(is_src_layout_nxc() == is_ddst_layout_nxc());

    preamble();

    mov(reg_input_baddr, ptr[this->param1 + GET_OFF(input)]);
    mov(reg_output_baddr, ptr[this->param1 + GET_OFF(output)]);
    mov(reg_filter_baddr, ptr[this->param1 + GET_OFF(filter)]);

    // The mask is needed by the bias path and by nxc tail processing.
    bool set_kmask = jcp.ch_tail > 0 && (jcp.with_bias || is_layout_nxc());
    if (set_kmask) {
        // Prepare masks for tail
        Reg32 reg_tmp_32 = reg_tmp.cvt32();
        mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
        kmovw(k_ch_tail_mask, reg_tmp_32);
    }

    compute_ow_block_unroll();

    postamble();
}
1512 | #undef GET_OFF |
1513 | |
1514 | } // namespace x64 |
1515 | } // namespace cpu |
1516 | } // namespace impl |
1517 | } // namespace dnnl |
1518 | |