1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/bfloat16.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/math_utils.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | |
28 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
30 | #include "cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp" |
31 | |
32 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
39 | using namespace format_tag; |
40 | using namespace dnnl::impl::memory_tracking::names; |
41 | using namespace dnnl::impl::utils; |
42 | using namespace Xbyak; |
43 | |
44 | namespace { |
45 | |
46 | constexpr auto small_spatial = 14; |
47 | |
48 | inline void pick_loop_order(jit_conv_conf_t &jcp) { |
49 | using namespace prop_kind; |
50 | assert(one_of( |
51 | jcp.prop_kind, forward_training, forward_inference, backward_data)); |
52 | auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; |
53 | auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; |
54 | |
55 | if (utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
56 | format_tag::nwc) |
57 | && jcp.ngroups > 1 && jcp.oc < 16) { |
58 | jcp.loop_order = loop_nhwcg; |
59 | } else if (jcp.prop_kind == backward_data) { |
60 | // ow-threading is currently implemented for forward only |
61 | // TODO: single code for fwd and bwd after ow-thr for bwd |
62 | // meaningless switch was removed |
63 | if (jcp.ndims < 5) |
64 | jcp.loop_order = (w <= small_spatial && h <= small_spatial) |
65 | ? loop_cwgn |
66 | : loop_gncw; |
67 | else |
68 | jcp.loop_order = (w <= small_spatial && h <= small_spatial) |
69 | ? loop_cgn |
70 | : loop_gnc; |
71 | } else { |
72 | jcp.loop_order = (w <= small_spatial && h <= small_spatial) ? loop_cwgn |
73 | : loop_gncw; |
74 | } |
75 | } |
76 | inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) { |
77 | /*is 1D conv */ |
78 | return (jcp.id == 1 && jcp.ih == 1 && jcp.kd == 1 && jcp.kh == 1); |
79 | } |
80 | inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { |
81 | return (jcp.nb_ow > 1); |
82 | } |
83 | inline bool is_iw_threading_available(const jit_conv_conf_t &jcp) { |
84 | return one_of(jcp.ndims, 3, 4); |
85 | } |
86 | inline bool is_iw_threading_on(const jit_conv_conf_t &jcp) { |
87 | return (jcp.nb_iw > 1); |
88 | } |
89 | inline bool is_1stconv(const jit_conv_conf_t &jcp) { |
90 | const bool no_big_offt = nstl::max<size_t>(jcp.ic, jcp.oc) |
91 | * nstl::max(jcp.typesize_in, jcp.typesize_out) * jcp.id |
92 | * jcp.ih * jcp.iw |
93 | < INT_MAX; |
94 | return jcp.ic < 16 && jcp.ngroups == 1 && no_big_offt; |
95 | } |
96 | } // namespace |
97 | |
// Kernel constructor: sets up the (optional) post-ops injector for
// eltwise/binary post-ops and the bf16 emulation helper for pre-bf16 ISAs.
template <typename Vmm>
_jit_avx512_core_bf16_fwd_kernel<Vmm>::_jit_avx512_core_bf16_fwd_kernel(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        // vmm31 is handed to the injector as its scratch vector register.
        static constexpr size_t helper_vmm_idx = 31;
        // Tail size for masked post-op stores: prefer the oc_block remainder;
        // when oc_block divides the simd width evenly, fall back to the
        // remainder of the unpadded oc (0 means no tail).
        const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
        const size_t tail_size = oc_block_tail
                ? oc_block_tail
                : jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r12, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, postops_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
                this, jcp.post_ops, static_params);
    }
    // On ISAs without native bf16 support, instantiate the emulator with its
    // dedicated reserved registers.
    if (!isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
}
133 | |
134 | template <typename Vmm> |
135 | void _jit_avx512_core_bf16_fwd_kernel<Vmm>::prepare_dst(int ur_w) { |
136 | for (int k = 0; k < jcp.nb_oc_blocking; k++) |
137 | for (int j = 0; j < ur_w; j++) { |
138 | Vmm vmm = vmm_dst(j, k); |
139 | vpxord(vmm, vmm, vmm); |
140 | } |
141 | } |
142 | |
143 | template <typename Vmm> |
144 | int _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst_idx( |
145 | int i_ur, int i_oc) const { |
146 | const int idx = i_ur * jcp.nb_oc_blocking + i_oc; |
147 | assert(idx < ker_reg_base_idx); |
148 | return idx; |
149 | } |
150 | |
151 | template <typename Vmm> |
152 | Vmm _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst(int i_ur, int i_oc) const { |
153 | return Vmm(vmm_dst_idx(i_ur, i_oc)); |
154 | } |
155 | |
// Invoke f(mask_flag, oc_block_idx, ur_idx) over the whole
// (nb_oc_block x ur_w) tile. mask_flag is raised for every call when
// force_masking is set, otherwise only for the trailing oc block when
// mask_tail is set.
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w, const bool mask_tail,
        const bool force_masking, const F &f) {
    for (int k = 0; k < nb_oc_block; k++) {
        const bool is_last_oc_block = (k + 1 == nb_oc_block);
        const bool mask_flag = force_masking || (mask_tail && is_last_oc_block);
        for (int j = 0; j < ur_w; j++)
            f(mask_flag, k, j);
    }
}
// Convenience overload: iterate the tile with masking fully disabled.
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
    iterate(nb_oc_block, ur_w, false, false, f);
}
170 | |
// Apply eltwise/binary post-ops to all live accumulator registers of the
// current tile. For binary post-ops a runtime branch selects between the
// masked (oc tail) and unmasked variants of the injected code.
template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::apply_postops(int ur_w) {
    if (jcp.with_eltwise || jcp.with_binary) {
        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            // Non-zero when the unpadded oc leaves a simd-width remainder.
            const auto mask_tail = jcp.oc_without_padding % jcp.simd_w;
            const bool oc_blk_is_smaller_than_vmm
                    = jcp.oc_block < isa_simd_width_;
            // Collect, per accumulator register, the dst register and element
            // offset the injector needs to address the rhs binary operand.
            iterate(jcp.nb_oc_blocking, ur_w, mask_tail,
                    oc_blk_is_smaller_than_vmm,
                    [&](const bool mask_flag, const int k, const int j) {
                        const size_t aux_output_l_off = get_dst_offset(j, k);
                        const auto vmm_idx = vmm_dst_idx(j, k);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, reg_dst);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            // The non-tail params are the tail params minus the tail mask set.
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            Label postops_done;
            if (mask_tail || oc_blk_is_smaller_than_vmm) {
                Label postops_no_tail;
                // At runtime, take the tail path only when load_work carries
                // an oc_block remainder.
                if (mask_tail) {
                    test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
                    jz(postops_no_tail, T_NEAR);
                }
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            // Eltwise-only: no per-register rhs info is needed.
            iterate(jcp.nb_oc_blocking, ur_w,
                    [&](const bool, const int k, const int j) {
                        vmm_idxs.emplace(vmm_dst_idx(j, k));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}
222 | |
// Finalize the output tile: optionally accumulate the previous dst (sum
// post-op) and bias, apply post-ops, then store the accumulators to dst in
// f32 or bf16 (native or emulated), masking tails where needed.
template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::store_dst(int ur_w) {
    Label store_label;
    const int oc_tail = jcp.oc_tail;
    if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();

    // sum post-op: add the existing dst values into the accumulators,
    // up-converting from bf16 (zero-extend + shift) when required.
    if (jcp.with_sum) {
        for (int k = 0; k < jcp.nb_oc_blocking; k++) {
            for (int j = 0; j < ur_w; j++) {
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
                Vmm vmm = vmm_dst(j, k);
                size_t aux_dst_offset = get_dst_offset(j, k);
                if (jcp.dst_dt == data_type::bf16) {
                    vpmovzxwd(may_be_mask_vmm(vmm_prev_dst, mask_flag, true),
                            make_safe_addr(
                                    reg_dst, aux_dst_offset, reg_long_offt));
                    vpslld(vmm_prev_dst, vmm_prev_dst, 16);
                    vaddps(vmm, vmm_prev_dst);
                } else {
                    vaddps(may_be_mask_vmm(vmm, mask_flag, true),
                            make_safe_addr(
                                    reg_dst, aux_dst_offset, reg_long_offt));
                }
            }
        }
    }

    // Bias: same up-convert-and-add scheme, one bias vector per oc block.
    if (jcp.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
        for (int k = 0; k < jcp.nb_oc_blocking; k++) {
            int bias_offset = jcp.typesize_bia * k * jcp.oc_block;
            for (int j = 0; j < ur_w; j++) {
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
                Vmm vmm = vmm_dst(j, k);
                if (jcp.bia_dt == data_type::bf16) {
                    vpmovzxwd(may_be_mask_vmm(vmm_bias, mask_flag, true),
                            EVEX_compress_addr(reg_bias, bias_offset));
                    vpslld(vmm_bias, vmm_bias, 16);
                    vaddps(vmm, vmm_bias);
                } else
                    vaddps(may_be_mask_vmm(vmm, mask_flag, true),
                            EVEX_compress_addr(reg_bias, bias_offset));
            }
        }
    }

    apply_postops(ur_w);

    L(store_label);
    if (jcp.dst_dt == data_type::f32) {
        // f32 dst: plain (possibly masked) vector stores.
        for (int k = 0; k < jcp.nb_oc_blocking; k++)
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_dst(j, k);
                size_t aux_dst_offset = get_dst_offset(j, k);
                auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking
                        && is_dst_layout_nxc();
                vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
            }
    } else if (jcp.dst_dt == data_type::bf16) {
        if (isa_has_bf16(jcp.isa) && is_dst_layout_nxc()) {
            // Optimization: use single store instruction for pair of the
            // nearest vectors along OC dimension
            for (int j = 0; j < ur_w; j++) {
                int k = 0;
                for (; k < rnd_dn(jcp.nb_oc_blocking, 2); k += 2) {
                    Vmm vmm = vmm_dst(j, k);
                    Vmm vmm_next = vmm_dst(j, k + 1);
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    // Pack two f32 accumulators into one bf16 vector.
                    vcvtne2ps2bf16(vmm, vmm_next, vmm);
                    // mask only needed for last oc_block
                    bool mask_flag = oc_tail && k + 2 == jcp.nb_oc_blocking;
                    vmovdqu16(
                            addr, may_be_mask_vmm(vmm, mask_flag, false, true));
                }
                // Odd trailing oc block: convert and store one vector alone.
                if (jcp.nb_oc_blocking % 2 != 0) {
                    Vmm vmm = vmm_dst(j, k);
                    auto vmm_down = Vmm_down_t(vmm.getIdx());
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    vcvtneps2bf16(vmm_down, vmm);
                    // for xmm, upper half is zero after conversion to
                    // bf16, so mask always & mask for tails
                    bool mask_flag = jcp.simd_w == 4 || oc_tail;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
            }
        } else if (isa_has_bf16(jcp.isa) /* !is_dst_layout_nxc() */) {
            // Optimization: use single store instruction for pair of the
            // nearest vectors along WIDTH dimension
            for (int k = 0; k < jcp.nb_oc_blocking; k++) {
                int n_2bf2ps = (ur_w / 2) * 2, j = 0;
                for (j = 0; j < n_2bf2ps; j += 2) {
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);

                    auto vmm_str = vmm_src(j, jcp.nb_oc_blocking);
                    vcvtne2ps2bf16(vmm_str, vmm_dst(j + 1, k), vmm_dst(j, k));
                    vmovups(addr, vmm_str);
                }
                // Odd trailing width position: single-vector conversion.
                if (j < ur_w) {
                    size_t aux_dst_offset = get_dst_offset(j, k);

                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    auto vmm_down_str = vmm_src_down(j, jcp.nb_oc_blocking);
                    vcvtneps2bf16(vmm_down_str, vmm_dst(j, k));
                    // for xmm, upper half is zero after conversion to
                    // bf16, so mask always.
                    const bool mask_flag = jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
                }
            }
        } else {
            // No native bf16: convert through the emulator, one vector at a
            // time.
            for (int k = 0; k < jcp.nb_oc_blocking; k++)
                for (int j = 0; j < ur_w; j++) {
                    Vmm vmm = vmm_dst(j, k);
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    auto vmm_down = vmm_src_down(0, jcp.nb_oc_blocking);
                    bf16_emu_->vcvtneps2bf16(
                            Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
                    bool mask_flag = (oc_tail && k + 1 == jcp.nb_oc_blocking
                                             && is_dst_layout_nxc())
                            // for xmm, upper half is zero after conversion to
                            // bf16, so mask always & mask for tails
                            || jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
        }
    } else
        assert(!"unsupported destination type" );
}
359 | |
// Emit the main accumulation loop for one ur_w-wide output strip: iterate
// over kd/kh/kw and the ic blocks, broadcasting src pairs and fusing them
// with the weights via (possibly emulated) vdpbf16ps, then store the tile.
template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::compute_loop(
        int ur_w, int pad_l, int pad_r) {
    Label kh_label, kd_label;
    const int ic_tail = jcp.ic_tail;
    // bf16 dot-product consumes input channels in pairs.
    const int ic_step = 2;

    /* max_src_offset is explicitly used in the 1st convolution.
     * Set its value so that accessing the double-word memory
     * referenced by ptr[src_base + offset] is safe whenever
     * 0 <= offset < max_src_offset
     *
     * Note: Since the arguments pad_l, pad_r might not exactly match
     * with jcp.l_pad and jcp.r_pad respectively so this value needs to be
     * computed separately for each invocation of the compute_loop.
     */
    dim_t max_src_offset = 0;
    if (jcp.is_1stconv || ic_tail) {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int ow_fst = get_ow_start(ki, pad_l);
            int ow_last = get_ow_end(ur_w, ki, pad_r) - 1;
            if (ow_fst > ow_last) continue;
            int ic_last = rnd_up(nstl::min(jcp.ic_block,
                                         nstl::max(jcp.ic, ic_tail)),
                                  ic_step)
                    - ic_step;

            dim_t src_offset = get_src_offset(
                    ic_last, filter_w_to_src(ki, ow_last, pad_l));
            if (src_offset > max_src_offset) max_src_offset = src_offset;
        }
    }

    prepare_dst(ur_w);

    // Skip the whole compute when dilation/padding makes the kernel fall
    // entirely outside the input (runtime kd/kh padding may be zero).
    Label skip_compute_loop;
    if (jcp.ndims == 5) {
        mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1)
                        < nstl::max(jcp.f_pad, jcp.back_pad)) {
            cmp(reg_kj, 0);
            je(skip_compute_loop, T_NEAR);
        }
    }
    mov(reg_kj, reg_kh);
    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1)
                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(reg_kj, 0);
        je(skip_compute_loop, T_NEAR);
    }

    // IC loop
    Label icb_label;
    mov(reg_ic, jcp.ic);
    L(icb_label);

    if (jcp.ndims == 5) {
        // 3D case: spill src/ker base pointers, reuse them per kd iteration.
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(ptr[rsp + off_reg_ker_], reg_ker);
        mov(ptr[rsp + off_reg_src_], reg_src);

        L(kd_label);
    }

    mov(aux_reg_src, reg_src);
    mov(aux_reg_ker, reg_ker);

    mov(reg_kj, reg_kh);

    // One forward label per kw position to bail out of the ic loop when the
    // runtime ic equals the tail.
    std::vector<Label> ic_tail_jmp(jcp.kw);
    L(kh_label);
    {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int ow_start = get_ow_start(ki, pad_l);
            int ow_end = get_ow_end(ur_w, ki, pad_r);
            for (int ic = 0;
                    ic < rnd_up(nstl::min(jcp.ic_block, jcp.ic), ic_step);
                    ic += ic_step) {
                if (ic_tail && ic == rnd_up(ic_tail, ic_step)) {
                    // insert this check at most once per icb, no more.
                    cmp(reg_ic, ic_tail);
                    je(ic_tail_jmp[ki], T_NEAR);
                }
                for (int oi = ow_start; oi < ow_end; oi++) {
                    dim_t src_offset = get_src_offset(
                            ic, filter_w_to_src(ki, oi, pad_l));
                    auto vmm_in = vmm_src(oi, jcp.nb_oc_blocking);
                    const auto addr_base = EVEX_compress_addr_safe(
                            aux_reg_src, src_offset, reg_long_offt);
                    const bool tail_load
                            = ic_tail && ic == rnd_dn(ic_tail, ic_step);
                    if (jcp.is_1stconv || tail_load) {
                        const bool need_single_load
                                = (ic + 1 == jcp.ic || ic + 1 == ic_tail);
                        const bool safe_overstep = (src_offset < max_src_offset)
                                && !is_src_layout_nxc();

                        /* For the comment below, let us define three words
                         * x_b = ptr[addr_base] and x_s = ptr[addr_strided]
                         * x_g = ptr[addr_base + 2]
                         *
                         * For single load case:
                         * Without overstep zmm_in register is loaded as
                         *     [0, x_b, ..., 0, x_b, 0, x_b]
                         * On the other hand, "with overstep" zmm_in register
                         * is loaded as
                         *     [x_g, x_b, ..., x_g, x_b, x_g, x_b]
                         * where x_g is a garbage word.
                         *
                         * Note:
                         * 1. In single load case with safe_overstep enabled,
                         * it is implicitly assumed that the element in zmm_wei
                         * register corresponding to the "garbage value x_g" in
                         * zmm_in register is zero.
                         * 2. One can have potential problem when x_g is
                         * either Inf or NaN since it is multiplied by zero
                         * in accumulation. But as x_g is a "valid input"
                         * for different offset so one might assume that x_g is
                         * neither Inf nor Nan.
                         *
                         * For non single load case:
                         * zmm_in register is loaded as
                         *     [x_s, x_b, ...., x_s, x_b, x_s, x_b]
                         */
                        if (tail_load) {
                            if (need_single_load) {
                                // Runtime choice: full dword broadcast when a
                                // channel pair remains, masked word broadcast
                                // otherwise.
                                Label mask_load, load_done;
                                cmp(reg_ic, ic + ic_step);
                                jl(mask_load, T_NEAR);
                                vpbroadcastd(vmm_in, addr_base);
                                jmp(load_done, T_NEAR);
                                L(mask_load);
                                vpbroadcastw(vmm_in | odd_load_mask | T_z,
                                        addr_base);
                                L(load_done);
                            } else {
                                vpbroadcastd(vmm_in, addr_base);
                            }
                        } else if (need_single_load && !safe_overstep)
                            vpbroadcastw(
                                    vmm_in | odd_load_mask | T_z, addr_base);
                        else if (IMPLICATION(!is_src_layout_nxc(),
                                         need_single_load && safe_overstep))
                            vpbroadcastd(vmm_in, addr_base);
                        else {
                            // Strided layout: gather the second channel word
                            // from its own row via a masked word broadcast.
                            const auto addr_strided
                                    = EVEX_compress_addr_safe(aux_reg_src,
                                            src_offset + get_src_offset(1, 0),
                                            reg_long_offt);
                            vpbroadcastd(vmm_in, addr_base);
                            vpbroadcastw(vmm_in | even_load_mask, addr_strided);
                        }
                    } else {
                        vpbroadcastd(vmm_in, addr_base);
                    }
                }
                // Multiply the broadcast src by each oc block's weights and
                // accumulate into the tile.
                for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
                    auto wei_off = get_kernel_offset(kk, ic, ki);
                    vmovups(vmm_wei,
                            EVEX_compress_addr_safe(
                                    aux_reg_ker, wei_off, reg_long_offt));
                    for (int oi = ow_start; oi < ow_end; oi++) {
                        auto acc = vmm_dst(oi, kk);
                        auto src = vmm_src(oi, jcp.nb_oc_blocking);
                        if (isa_has_bf16(jcp.isa)) {
                            vdpbf16ps(acc, vmm_wei, src);
                        } else
                            bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
                                    Zmm(vmm_wei.getIdx()), Zmm(src.getIdx()));
                    }
                }
            }
            L(ic_tail_jmp[ki]);
        }
        // Advance to the next kh row of weights and src.
        safe_add(aux_reg_ker, get_kernel_offset(0, 0, 0, 1), reg_long_offt);
        safe_add(aux_reg_src, get_src_offset(0, filter_h_to_src(1)),
                reg_long_offt);

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        // Advance to the next kd slice, then restore the spilled bases.
        safe_add(reg_src, get_src_offset(0, filter_d_to_src(1)), reg_long_offt);
        safe_add(reg_ker, get_kernel_offset(0, 0, 0, 0, 1), reg_long_offt);
        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        mov(reg_ker, ptr[rsp + off_reg_ker_]);
        mov(reg_src, ptr[rsp + off_reg_src_]);
    }

    // End of IC Loop
    dim_t src_step = get_src_offset(jcp.ic_block, 0);
    const size_t ker_step = get_kernel_offset(0, jcp.ic_block, 0);
    safe_add(reg_src, src_step, reg_long_offt);
    safe_add(reg_ker, ker_step, reg_long_offt);

    sub(reg_ic, jcp.ic_block);
    cmp(reg_ic, 0);
    jg(icb_label, T_NEAR);

    // Rewind src/ker to their pre-ic-loop positions.
    safe_sub(reg_src, src_step * jcp.nb_ic, reg_long_offt);
    safe_sub(reg_ker, ker_step * jcp.nb_ic, reg_long_offt);

    L(skip_compute_loop);
    store_dst(ur_w);
}
572 | |
// Top-level code generation: set up masks and argument registers, then walk
// the output width either as one whole strip (left pad / body / right pad /
// tail) or as runtime-selected ow blocks when ow-threading is enabled.
template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::generate() {
    int iw = jcp.iw;
    int ow = jcp.ow;
    int ow_block = jcp.ow_block;
    int nb_ow = jcp.nb_ow;
    int kw = jcp.kw;
    int l_pad = jcp.l_pad;
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int stride_w = jcp.stride_w;

    // Pointer advances per processed ur_w strip (with/without left padding).
    auto src_shift = get_src_offset(0, filter_w_to_src(0, ur_w));
    auto dst_shift = get_dst_offset(ur_w, 0);

    auto src_shift_pad = get_src_offset(0, filter_w_to_src(0, ur_w, l_pad));
    auto src_shift_pad_second_block
            = get_src_offset(0, filter_w_to_src(0, 0, l_pad));

    preamble();
    if (jcp.ndims == 5) sub(rsp, stack_space_needed_);

    // Alternating-word masks used by the 1st-conv / ic-tail broadcast loads.
    if (jcp.is_1stconv || jcp.ic_tail) {
        Xbyak::Reg64 reg_alt_mask = r8;
        const auto odd_mask = size_t {0x5555555555555555};
        const auto even_mask = size_t {0xaaaaaaaaaaaaaaaa};
        mov(reg_alt_mask, odd_mask);
        kmovq(odd_load_mask, reg_alt_mask);
        mov(reg_alt_mask, even_mask);
        kmovq(even_load_mask, reg_alt_mask);
    }

    // xmm kernels always store through a 4-lane mask.
    if (jcp.simd_w == 4) {
        auto reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.simd_w) - 1);
        kmovb(k_oc_tail_mask, reg_tail_32);
    }

    if (jcp.oc_tail) {
        Label done;
        // dummy mask all 1's
        if (jcp.simd_w != 4) { // simd_w == 4, has its dummy mask set already
            kxnord(k_oc_tail_mask, k_oc_tail_mask, k_oc_tail_mask);
        }
        // To account for special store optimization, where two oc_blocks are
        // combined with one single write, extend the mask for 32bits (32 bf16s)
        const bool need_extended_mask = jcp.dst_dt == data_type::bf16
                && isa_has_bf16(jcp.isa) && jcp.nb_oc_blocking > 1;
        if (need_extended_mask)
            kxnord(k_oc_tail_mask_extended, k_oc_tail_mask_extended,
                    k_oc_tail_mask_extended);

        // Install the real tail masks only when load_work actually carries
        // an oc_block remainder at runtime.
        test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
        jz(done, T_NEAR);
        auto reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.oc_tail) - 1);
        kmovd(k_oc_tail_mask, reg_tail_32);
        kmovd(postops_mask, reg_tail_32);
        if (need_extended_mask) {
            mov(reg_tail_32, (1 << (jcp.oc_tail + jcp.simd_w)) - 1);
            kmovd(k_oc_tail_mask_extended, reg_tail_32);
        }
        L(done);
    } else if (jcp.with_binary)
        // Binary post-ops need a mask even without an oc tail when oc_block
        // is narrower than the full vector.
        if (jcp.oc_block != isa_simd_width_) {
            const int mask = (1 << jcp.oc_block) - 1;
            const Reg32 regw_tmp = reg_oi.cvt32();
            mov(regw_tmp, mask);
            kmovd(postops_mask, regw_tmp);
        }

    // Load kernel arguments.
    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);

    int r_pad = nstl::max(0, jcp.r_pad);
    int n_oi = ow / ur_w;
    // Right padding seen by the last full ur_w strip.
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));

    if (!is_ow_threading_on(jcp)) {
        // ow is being processed as a whole - with left and right paddings
        if (r_pad1 > 0) n_oi--;

        xor_(reg_oi, reg_oi);
        if (ow == ur_w) {
            compute_loop(ur_w, l_pad, r_pad);
        } else {
            if (n_oi == 0) {
                compute_loop(ur_w, l_pad, r_pad1);
                add(reg_src, src_shift_pad);
                add(reg_dst, dst_shift);
                if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
            } else {
                if (l_pad > 0) {
                    compute_loop(ur_w, l_pad, 0);
                    add(reg_src, src_shift_pad);
                    add(reg_dst, dst_shift);
                    inc(reg_oi);
                }
                // Unpadded middle strips run in a runtime loop.
                if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
                    Label ow_loop_label;
                    L(ow_loop_label);
                    {
                        compute_loop(ur_w, 0, 0);
                        add(reg_src, src_shift);
                        add(reg_dst, dst_shift);

                        inc(reg_oi);
                        cmp(reg_oi, n_oi);
                        jl(ow_loop_label, T_NEAR);
                    }
                }
                if (r_pad1 > 0) {
                    compute_loop(ur_w, 0, r_pad1);
                    add(reg_src, src_shift);
                    add(reg_dst, dst_shift);
                }
                if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
            }
        }
    } else {
        // ow block is only processed.
        // Number of block is passed as parameter owb,
        // and padding processing depends on this number.

        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
        Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;

        assert(ow_block % ur_w == 0);
        int n_oi_not_last_ow_block = ow_block / ur_w;
        // to simplify code (and general regs usage),
        // size of ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;

        int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;

        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded
                = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;

        if (last_ow_block_padded)
            n_oi_last_ow_block--;
        else if (first_ow_block_padded)
            n_oi_first_ow_block--;
        else if (next_last_ow_block_padded)
            n_oi_next_last_ow_block--;

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is that the first ow-block ?
        jg(middle_ow_blocks_label, T_NEAR);

        // the first ow block, compute left padding

        mov(reg_oi, n_oi_first_ow_block);
        if (l_pad > 0) {
            compute_loop(ur_w, l_pad, 0);
            add(reg_src, src_shift_pad);
            add(reg_dst, dst_shift);
            dec(reg_oi);
        }
        jmp(oi_loop_label, T_NEAR);

        // middle or last ow block entry

        L(middle_ow_blocks_label);

        if (l_pad > 0) {
            // just to consider left padding, not compute
            add(reg_src, src_shift_pad_second_block);
        }

        // set number of iteration for oi-loop
        cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
        mov(reg_oi, n_oi_last_ow_block);
        je(oi_loop_label, T_NEAR);
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        mov(reg_oi, n_oi_next_last_ow_block);
        je(oi_loop_label, T_NEAR);
        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks

        // oi loop w/o padding
        L(oi_loop_label);
        L(oi_loop_start_label);
        cmp(reg_oi, 0);
        jle(oi_loop_end_label, T_NEAR);

        compute_loop(ur_w, 0, 0);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);
        dec(reg_oi);
        jmp(oi_loop_start_label, T_NEAR);
        L(oi_loop_end_label);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

        cmp(reg_owb, 0); // first ow-block ?
        if (first_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        jl(end_label, T_NEAR);
        if (next_last_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        // that is last block
        if (!last_ow_block_padded) { jmp(tail_label, T_NEAR); }

        // last oi block with right padding
        L(last_oi_label);
        compute_loop(ur_w, 0, r_pad1);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        jl(end_label, T_NEAR);

        L(tail_label);
        if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
        L(end_label);
    }

    if (jcp.ndims == 5) add(rsp, stack_space_needed_);
    postamble();

    // Eltwise lookup tables must be emitted after the kernel body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
810 | |
811 | void jit_avx512_core_bf16_fwd_kernel::init_scratchpad( |
812 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
813 | using namespace memory_tracking::names; |
814 | if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) { |
815 | assert(jcp.ngroups == 1); |
816 | scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia); |
817 | } |
818 | } |
819 | |
// Validates the convolution problem and fills jit_conv_conf_t for the bf16
// forward kernel: geometry, memory formats, simd/blocking parameters and
// post-ops. Initializes any "format_kind::any" memory descriptors to the
// layouts the kernel expects; returns status::unimplemented for any shape,
// layout or post-op chain the kernel cannot handle.
status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {

    using namespace prop_kind;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    // An extra leading weights dimension indicates grouped convolution.
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;
    // Prefer native bf16 instructions; otherwise fall back to the emulation
    // isa reported by bf16_emulation_t.
    jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                        : bf16_emulation_t::get_isa();
    jcp.has_vnni = true;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    // Spatial sizes: depth exists only for 3D (ndims == 5), height is 1 for
    // 1D (ndims == 3).
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.dst_dt = dst_d.data_type();

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    // Effective (dilated) filter sizes and the right/bottom/back padding
    // implied by the requested output geometry.
    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
    // Reject shapes where some filter application lies entirely inside the
    // padding (no overlap with real source data).
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    // Candidate data layouts: channels-last (nxc), plain channels-first, and
    // channel-blocked by 4/8/16, chosen per spatial rank.
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
            dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
    auto curr_dst_tag = dst_d.matches_one_of_tag(
            dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
    // Channels-last execution is chosen only if no user-fixed src/dst format
    // contradicts it and at least one of the two actually requests nxc.
    bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
                                      src_d.format_kind() == format_kind::any)
            && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                    dst_d.format_kind() == format_kind::any)
            && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
    jcp.is_1stconv = is_1stconv(jcp);

    // Register budget for accumulators/broadcasts; the emulation path keeps
    // extra registers reserved for the bf16 conversion scratch.
    const int regs = isa_has_bf16(jcp.isa) ? 31 /* expl_bcast case */ : 26;
    const bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;

    jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);

    // For small-channel grouped shapes that cannot be zero-padded, try a
    // narrower vector width (8 or 4 floats) that divides both ic and oc.
    const bool ok_to_try_lower_zmm = true
            && IMPLICATION(is_data_layout_nxc,
                    jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
                            && jcp.ngroups > 1)
            && !jcp.is_1stconv && !ok_to_pad_channels
            && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);

    if (ok_to_try_lower_zmm) {
        for (auto simd : {8, 4}) {
            if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
                jcp.simd_w = simd;
                break;
            }
        }
    }

    jcp.oc_block = jcp.simd_w;
    // First convolution keeps the (small) input channel count unblocked.
    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }

    // Blocked layouts require channel counts to be block-size multiples.
    if (!IMPLICATION(!is_data_layout_nxc,
                jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
        return status::unimplemented;

    format_tag_t src_tag, dst_tag, wei_tag;

    // Pick matching src/dst/weights layout tags for the chosen simd width
    // and 1st-conv case. The reduced simd widths only arise for grouped
    // convolutions (see ok_to_try_lower_zmm above).
    if (jcp.simd_w == 8) {
        assert(with_groups);
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
        wei_tag = pick(ndims - 3, gOIw4i8o2i, gOIhw4i8o2i, gOIdhw4i8o2i);
    } else if (jcp.simd_w == 4) {
        assert(with_groups);
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
        wei_tag = pick(ndims - 3, gOIw2i4o2i, gOIhw2i4o2i, gOIdhw2i4o2i);
    } else if (jcp.is_1stconv) {
        dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
        src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx;
        wei_tag = pick(2 * ndims - 6 + with_groups, OwI16o2i, gOwI16o2i,
                OhwI16o2i, gOhwI16o2i, OdhwI16o2i, gOdhwI16o2i);
    } else {
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
        wei_tag = pick(2 * ndims - 6 + with_groups, OIw8i16o2i, gOIw8i16o2i,
                OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i, gOIdhw8i16o2i);
    }

    // Initialize "any" descriptors to the chosen tags; reject user-fixed
    // formats that do not match.
    if (src_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(src_md, src_tag));
    else if (curr_src_tag != src_tag)
        return status::unimplemented;
    jcp.src_tag = src_tag;

    if (dst_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
    else if (curr_dst_tag != dst_tag)
        return status::unimplemented;
    jcp.dst_tag = dst_tag;

    if (weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));
    }

    jcp.aligned_threads = 0;

    // The (possibly rounded-up) channel counts must fit in the padded dims
    // of the actual memory descriptors.
    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    // Post-ops: record presence of sum / eltwise / binary.
    const auto &post_ops = attr.post_ops_;
    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
        jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
    }
    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    // Channel tails: real tails exist only for nxc; for blocked dst a tail
    // matters only to mask binary post-op accesses over the padded area.
    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
    if (is_data_layout_nxc)
        jcp.oc_tail = jcp.oc % jcp.simd_w;
    else
        jcp.oc_tail = jcp.with_binary ? jcp.oc_without_padding % jcp.simd_w : 0;

    if (attr.set_default_formats(&dst_md) != status::success)
        return status::unimplemented;

    jcp.post_ops = post_ops;

    // Validate the post-op chain against what the injector supports.
    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    // Choose the number of oc blocks processed per row pass: the largest
    // value <= 4 that divides nb_oc and still leaves a usable width unroll.
    jcp.kernel_kind = expl_bcast;
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
        int ur_w = regs / (jcp.nb_oc_blocking + 1);
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && (jcp.l_pad <= ur_w
                        && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
            break;
    }

    jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    // Optionally split ow across threads; size the ow block so that one
    // src/dst/wei chunk fits in a fraction (5/8) of the per-core L1 cache.
    jcp.ow_block = jcp.ow;
    if (is_ow_threading_available(jcp)) {
        const int L1_part = platform::get_per_core_cache_size(1) * 5 / 8;
        int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
        int size_dst_chunk = jcp.typesize_out * jcp.oc_block
                * jcp.nb_oc_blocking * jcp.ur_w;
        int size_wei_chunk = jcp.typesize_in * jcp.oc_block * jcp.ic_block
                * jcp.nb_oc_blocking * jcp.kw;
        int nurw = (L1_part - size_wei_chunk)
                / (size_dst_chunk + size_src_chunk);
        // current design of generate() requires ow_block >= 2 * ur_w
        jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
    }
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);

    // Left/right padding must each be coverable by one unrolled row pass.
    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));
    if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    /* adjust the thread decomposition
     * to improve the perf for small problem size
     * the threshold L1_cache_size/factor and the factor is empirical
     * simply set the thread to 4 for now
     * TODO: Add get_thr_eff func to get optimal thread number */

    size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
            * jcp.kw * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    // The factor for 1d=1, 2d=2, 3d=4;
    int factor = nstl::max(1, (2 * (ndims - 3)));
    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    pick_loop_order(jcp);

    return status::success;
}
1088 | |
1089 | template <typename Vmm> |
1090 | void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::prepare_output(int ur_w) { |
1091 | for (int k = 0; k < jcp.nb_ic_blocking; k++) { |
1092 | for (int j = 0; j < ur_w; j++) { |
1093 | Vmm vmm = vmm_dsrc(j, k); |
1094 | vpxord(vmm, vmm, vmm); |
1095 | } |
1096 | } |
1097 | } |
1098 | |
// Emit code that writes the accumulated (f32) diff_src registers back to
// memory. For f32 output this is plain stores; for bf16 output the values
// are down-converted first, either natively (vcvtne2ps2bf16/vcvtneps2bf16,
// pairing vectors along IC or width to use full-width stores) or via the
// bf16 emulator. Tail/partial stores are handled with opmask registers.
template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
    if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
    const int ic_tail = jcp.ic_tail;

    if (jcp.dst_dt == data_type::f32) {
        // f32 diff_src: direct stores; masking applies only to the tail of
        // the last ic block in the channels-last layout.
        for (int k = 0; k < jcp.nb_ic_blocking; k++)
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_dsrc(j, k);
                size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                // mask only needed for last ic_block
                bool mask_flag = ic_tail && k + 1 == jcp.nb_ic_blocking
                        && is_dsrc_layout_nxc();
                vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
            }
    } else if (jcp.dst_dt == data_type::bf16) {
        if (isa_has_bf16(jcp.isa) && is_ddst_layout_nxc()) {
            // Optimization: use single store instruction for pair of the
            // nearest vectors along IC dimension
            for (int j = 0; j < ur_w; j++) {
                int k = 0;
                for (; k < rnd_dn(jcp.nb_ic_blocking, 2); k += 2) {
                    Vmm vmm = vmm_dsrc(j, k);
                    Vmm vmm_next = vmm_dsrc(j, k + 1);
                    size_t aux_dsrc_offset = get_diff_src_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
                    // Pack two f32 vectors into one full bf16 vector.
                    vcvtne2ps2bf16(vmm, vmm_next, vmm);
                    bool mask_flag = ic_tail && k + 2 == jcp.nb_ic_blocking;
                    vmovdqu16(
                            addr, may_be_mask_vmm(vmm, mask_flag, false, true));
                }
                // Odd trailing ic block: convert and store one half-width
                // vector on its own.
                if (jcp.nb_ic_blocking % 2 != 0) {
                    Vmm vmm = vmm_dsrc(j, k);
                    auto vmm_down = Vmm_down_t(vmm.getIdx());
                    size_t aux_dsrc_offset = get_diff_src_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
                    vcvtneps2bf16(vmm_down, vmm);
                    // for xmm, upper half is zero after conversion to
                    // bf16, so mask always & mask for tails
                    bool mask_flag = jcp.simd_w == 4 || ic_tail;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
            }
        } else if (isa_has_bf16(jcp.isa) /* && !is_ddst_layout_nxc() */) {
            // Optimization: use single store instruction for pair of the
            // nearest vectors along WIDTH dimension
            int store_idx = 0;
            const int max_regs = 32;
            // Registers above the accumulator range are free scratch for the
            // packed results; cycle through them round-robin.
            const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
            const int num_regs_available = max_regs - free_regs_start_idx;
            int reg_idx = 0;
            for (int k = 0; k < jcp.nb_ic_blocking; k++) {
                int n_2bf2ps = (ur_w / 2) * 2, j = 0;
                for (j = 0; j < n_2bf2ps; j += 2) {
                    reg_idx = free_regs_start_idx
                            + store_idx % num_regs_available;
                    assert(reg_idx < max_regs);
                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);

                    auto vmm_str = Vmm(reg_idx);
                    vcvtne2ps2bf16(vmm_str, vmm_dsrc(j + 1, k), vmm_dsrc(j, k));
                    vmovups(addr, vmm_str);
                    store_idx++;
                }
                // Odd trailing width position: half-width convert + store.
                if (j < ur_w) {
                    reg_idx = free_regs_start_idx
                            + store_idx % num_regs_available;
                    assert(reg_idx < max_regs);

                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                    auto vmm_down_str = Vmm_down_t(reg_idx);
                    vcvtneps2bf16(vmm_down_str, vmm_dsrc(j, k));
                    // for xmm, upper half is zero after conversion to
                    // bf16, so mask always.
                    bool mask_flag = jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
                    store_idx++;
                }
            }
        } else {
            // No native bf16 support: down-convert each vector through the
            // emulator, then store the (half-width) result.
            for (int k = 0; k < jcp.nb_ic_blocking; k++)
                for (int j = 0; j < ur_w; j++) {
                    Vmm vmm = vmm_dsrc(j, k);
                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                    auto vmm_down = vmm_ddst_down(0);
                    bf16_emu_->vcvtneps2bf16(
                            Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
                    bool mask_flag = (ic_tail && k + 1 == jcp.nb_ic_blocking
                                             && is_dsrc_layout_nxc())
                            // for xmm, upper half is zero after conversion to
                            // bf16, so mask always & mask for tails
                            || jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
        }
    } else
        assert(!"unsupported diff_src type" );
}
1204 | |
// Emit the compute loop for one strip of ur_w diff_src pixels: loop over
// output-channel blocks, kernel depth/height (runtime counts, pre-clipped
// by the driver via kd_padding/kh_padding) and kernel width (compile-time),
// broadcasting pairs of diff_dst values and accumulating with vdpbf16ps
// against the oc-pair-interleaved weights. l_overflow/r_overflow give the
// number of kernel taps falling outside diff_dst on the left/right.
template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(
        int ur_w, int l_overflow, int r_overflow) {
    int kw = jcp.kw;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;
    const int oc_tail = jcp.oc_tail;
    Label kh_label, skip_compute_label;

    prepare_output(ur_w);

    // If kd/kh padding leaves no valid taps at runtime, skip straight to the
    // store: the accumulators stay zero.
    if (jcp.ndims == 5) {
        mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
        cmp(reg_ki, 0);
        jle(skip_compute_label, T_NEAR);
    }

    cmp(reg_kh, 0);
    jle(skip_compute_label, T_NEAR);

    // OC loop
    Label ocb_label;
    mov(reg_oc, jcp.oc);
    L(ocb_label);

    if (jcp.ndims < 5) {
        mov(aux_reg_dst, reg_dst);
        mov(aux_reg_ker, reg_ker);
    }
    Label kd_label;
    if (jcp.ndims == 5) {
        mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
        mov(aux_reg_dst_d, reg_dst);
        mov(aux_reg_ker_d, reg_ker);

        L(kd_label);
        mov(aux_reg_dst, aux_reg_dst_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    }

    // One oc-tail exit label per kernel-width position.
    std::vector<Label> oc_tail_jmp(jcp.kw);
    mov(reg_kj, reg_kh);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            // Valid diff_src index range for this kernel column given the
            // left/right overflow; cross-checked against the closed-form
            // expressions for the unit-stride case.
            int jj_start = get_iw_start(ki, l_overflow);
            int jj_end = get_iw_end(ur_w, ki, r_overflow);
            const int ref_jj_start
                    = nstl::max(0, l_overflow - (kw - 1 - ki) * dilate_w);
            const int ref_jj_end
                    = ur_w - nstl::max(0, r_overflow - ki * dilate_w);
            assert(IMPLICATION(stride_w == 1,
                    jj_start == ref_jj_start && jj_end == ref_jj_end));
            UNUSED(ref_jj_start);
            UNUSED(ref_jj_end);
            // vdpbf16ps consumes oc values in pairs, hence the step of 2.
            const int oc_step = 2;
            for (int oc = 0;
                    oc < rnd_up(nstl::min(jcp.oc_block, jcp.oc), oc_step);
                    oc += oc_step) {
                // Once the runtime oc count reaches the tail, bail out of
                // this kernel-width position.
                if (oc_tail && oc == rnd_up(oc_tail, oc_step)) {
                    cmp(reg_oc, oc_tail);
                    je(oc_tail_jmp[ki], T_NEAR);
                }
                for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                    assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
                    int ow_idx = (jj + jcp.l_pad - ki * dilate_w) / stride_w;
                    auto aux_ddst_offset = get_diff_dst_offset(ow_idx, oc);
                    auto ddst = vmm_ddst(jj / stride_w);
                    const bool tail_load = oc_tail && oc == rnd_dn(oc_tail, 2);
                    const bool need_single_load = oc + 1 == oc_tail;

                    if (tail_load && need_single_load) {
                        // Runtime choice between a dword (oc pair) broadcast
                        // and a single-word broadcast for an odd oc tail.
                        Label mask_load, load_done;
                        cmp(reg_oc, oc + 2);
                        jl(mask_load, T_NEAR);
                        vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                        jmp(load_done, T_NEAR);
                        L(mask_load);
                        // We broadcast w here. As the weights are zero-padded
                        // at oc + 1, vdpbf16ps({0, w}, {dst, dst}) is okay.
                        vpbroadcastw(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                        L(load_done);
                    } else {
                        vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                    }
                }
                for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
                    size_t aux_kernel_offset = get_kernel_offset(kk, oc, ki);
                    vmovups(vmm_wei,
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));

                    // Accumulate: one dot-product per diff_src position.
                    for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                        auto ddst = vmm_ddst(jj / stride_w);
                        auto acc = vmm_dsrc(jj, kk);

                        if (isa_has_bf16(jcp.isa)) {
                            vdpbf16ps(acc, vmm_wei, ddst);
                        } else
                            bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
                                    Zmm(vmm_wei.getIdx()), Zmm(ddst.getIdx()));
                    }
                }
            }
            L(oc_tail_jmp[ki]);
        }

        // Next kh row: weights advance forward, diff_dst steps backward
        // (bwd-data traverses dst rows in reverse).
        add(aux_reg_ker, get_kernel_offset(0, 0, 0, stride_h));
        sub(aux_reg_dst, get_diff_dst_offset(filter_h_to_dst(1), 0));

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        sub(aux_reg_dst_d, get_diff_dst_offset(filter_d_to_dst(1), 0));
        add(aux_reg_ker_d, get_kernel_offset(0, 0, 0, 0, jcp.stride_d));

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);
    }

    // End of OC Loop
    auto diff_dst_step = get_diff_dst_offset(0, 0, 1);
    auto ker_step = get_kernel_offset(0, jcp.oc_block, 0);
    add(reg_dst, diff_dst_step);
    add(reg_ker, ker_step);

    sub(reg_oc, jcp.oc_block);
    cmp(reg_oc, 0);
    jg(ocb_label, T_NEAR);

    // Rewind base pointers so the caller sees them unchanged.
    sub(reg_dst, diff_dst_step * jcp.nb_oc);
    sub(reg_ker, ker_step * jcp.nb_oc);

    L(skip_compute_label);
    store_output(ur_w);
}
1345 | |
// Top-level code generation for the bwd-data kernel. Splits the iw axis
// into head (left-overflow), body, pretail (right-overflow) and tail
// (ur_w_tail) sections; when iw-threading is on, also emits dispatch code
// that jumps to the section owned by the current thread's iw block (iwb).
template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ur_w = jcp.ur_w;
    int nb_iw = jcp.nb_iw;
    int iw_block = jcp.iw_block;
    int ur_w_tail = jcp.ur_w_tail;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    // Byte advance of diff_dst / diff_src pointers per ur_w strip.
    const auto dst_shift = get_diff_dst_offset(ur_w / stride_w, 0);
    const auto src_shift = get_diff_src_offset(ur_w, 0);

    preamble();

    // simd_w == 4 uses xmm stores whose upper half must always be masked.
    if (jcp.simd_w == 4) {
        Reg32 reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.simd_w) - 1);
        kmovb(k_ic_tail_mask, reg_tail_32);
    }

    // Build the ic-tail store mask(s) at runtime: full mask when this call
    // covers whole blocks, a partial mask when load_work has a tail.
    if (jcp.ic_tail) {
        Label done;
        // dummy mask all 1's
        if (jcp.simd_w != 4)
            kxnord(k_ic_tail_mask, k_ic_tail_mask, k_ic_tail_mask);
        // To account for special store optimization, where two ic_blocks are
        // combined with one single write, extend the mask for 32bits (32 bf16s)
        const bool need_extended_mask
                = isa_has_bf16(jcp.isa) && jcp.nb_ic_blocking > 1;
        if (need_extended_mask)
            kxnord(k_ic_tail_mask_extended, k_ic_tail_mask_extended,
                    k_ic_tail_mask_extended);

        test(byte[param1 + GET_OFF(load_work)], jcp.ic_block - 1);
        jz(done, T_NEAR);
        Reg32 reg_tail_32 = reg_ic.cvt32();
        mov(reg_tail_32, (1 << jcp.ic_tail) - 1);
        kmovd(k_ic_tail_mask, reg_tail_32);
        if (need_extended_mask) {
            mov(reg_tail_32, (1 << (jcp.ic_tail + jcp.simd_w)) - 1);
            kmovd(k_ic_tail_mask_extended, reg_tail_32);
        }
        L(done);
    }

    // Load kernel arguments.
    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_dst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);

    mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);

    // Compile-time overflow counts: taps outside diff_dst at the left edge,
    // at the right edge of a full strip, and at the right edge accounting
    // for the ur_w_tail strip.
    int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
    int r_overflow = nstl::max(
            0, ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad)) / stride_w);
    int r_overflow1 = nstl::max(0,
            ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad + ur_w_tail))
                    / stride_w);

    int body_l_overflow = 0, body_r_overflow = 0;
    int n_oi = iw / ur_w;
    int head_n_oi = 0, body_n_oi = 0, pretail_n_oi = 0, tail_n_oi = 0;
    int head_thread = 0, pretail_thread = 0, tail_thread = 0;
    bool threaded = is_iw_threading_on(jcp);
    Label head_label, body_label, pretail_label, tail_label, end_label;
    assert(n_oi > 0);

    // Head/pretail strips are carved out of the body iteration count.
    if (r_overflow1 > 0) n_oi--;
    if (l_overflow > 0) n_oi--;
    if (n_oi < 0) {
        // l_overflow and r_overflow1 are handled in the same compute_loop.
        // Perform one iteration of body handling l_overflow and r_overflow1.
        body_l_overflow = l_overflow;
        body_r_overflow = r_overflow1;
        n_oi = 1;
        l_overflow = 0;
        r_overflow1 = 0;
    }

    if (!threaded) {
        if (n_oi > 1) { mov(reg_oi, n_oi); }
    } else {
        // Setup for threaded code generation, and jump into the correct
        // portion of code for execution.
        head_thread = 0;
        tail_thread = nb_iw - 1;
        pretail_thread = tail_thread;

        // Per-thread body iteration counts; overflow strips reduce the
        // counts of the threads that own them.
        int base_n_oi = iw_block / ur_w;
        head_n_oi = l_overflow > 0 ? base_n_oi - 1 : base_n_oi;
        tail_n_oi = (iw - iw_block * (nb_iw - 1)) / ur_w;
        pretail_n_oi = tail_n_oi;
        if (r_overflow1 > 0) {
            if (tail_n_oi > 0) {
                pretail_n_oi--;
                tail_n_oi = pretail_n_oi;
            } else {
                // pretail_thread and tail_thread are different
                pretail_n_oi = base_n_oi - 1;
                pretail_thread = tail_thread - 1;
            }
            if (head_thread == pretail_thread) {
                head_n_oi--;
                pretail_n_oi = 0;
                tail_n_oi = 0;
            }
        }
        body_n_oi = (head_thread < pretail_thread - 1) ? base_n_oi : 0;

        // n_oi is used to determine how much control flow in the body portion
        // of the code needs generated. As such, n_oi needs to be set to the
        // maximum number of iterations it will be used the body code section.
        n_oi = nstl::max(body_n_oi, head_n_oi);
        n_oi = nstl::max(n_oi, pretail_n_oi);

        // Dispatch on the runtime iw-block index to the owning section.
        assert(iw_block % ur_w == 0);
        mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);

        if (head_n_oi != 0) mov(reg_oi, head_n_oi);
        cmp(reg_iwb, head_thread);
        je(head_label, T_NEAR);

        cmp(reg_iwb, pretail_thread);
        if (pretail_n_oi == 0) {
            je(pretail_label, T_NEAR);
        } else {
            mov(reg_oi, pretail_n_oi);
            je(body_label, T_NEAR);
        }
        if (pretail_thread != tail_thread) {
            cmp(reg_iwb, tail_thread);
            je(tail_label, T_NEAR);
        }
        if (body_n_oi != 0) {
            mov(reg_oi, body_n_oi);
            jmp(body_label, T_NEAR);
        } else {
            jmp(end_label, T_NEAR);
        }
    }
    // Head section: one strip handling the left overflow.
    L(head_label);
    if (l_overflow > 0) {
        compute_loop(ur_w, l_overflow, 0);
        if (threaded && head_n_oi == 0 && head_thread != pretail_thread)
            jmp(end_label, T_NEAR);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);
    }
    // Body section: reg_oi full strips with no (or merged) overflow.
    L(body_label);
    if (n_oi > 0) {
        Label ow_loop_label;
        L(ow_loop_label);
        {
            compute_loop(ur_w, body_l_overflow, body_r_overflow);
            if (n_oi > 1 || r_overflow1 > 0 || ur_w_tail != 0) {
                add(reg_src, src_shift);
                add(reg_dst, dst_shift);
            }
            if (n_oi > 1) {
                sub(reg_oi, 1);
                jg(ow_loop_label, T_NEAR);
            }
        }
    }
    if (threaded) {
        cmp(reg_iwb, pretail_thread);
        jne(end_label, T_NEAR);
    }
    // Pretail section: one strip handling the right overflow of full strips.
    L(pretail_label);
    if (r_overflow1 > 0) {
        compute_loop(ur_w, 0, r_overflow1);
        if (ur_w_tail != 0) {
            if (threaded && tail_thread != pretail_thread)
                jmp(end_label, T_NEAR);
            else {
                add(reg_src, src_shift);
                add(reg_dst, dst_shift);
            }
        }
    }
    // Tail section: the final partial strip of ur_w_tail pixels.
    L(tail_label);
    if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_overflow); }
    L(end_label);

    postamble();
}
1533 | |
1534 | status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp, |
1535 | const convolution_desc_t &cd, memory_desc_t &diff_src_md, |
1536 | memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) { |
1537 | |
1538 | const memory_desc_wrapper diff_src_d(&diff_src_md); |
1539 | const memory_desc_wrapper weights_d(&weights_md); |
1540 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
1541 | |
1542 | const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; |
1543 | int ndims = diff_src_d.ndims(); |
1544 | |
1545 | jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
1546 | : bf16_emulation_t::get_isa(); |
1547 | jcp.nthr = nthreads; |
1548 | jcp.has_vnni = true; |
1549 | jcp.ndims = ndims; |
1550 | jcp.prop_kind = cd.prop_kind; |
1551 | |
1552 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1553 | jcp.mb = diff_src_d.dims()[0]; |
1554 | |
1555 | jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; |
1556 | jcp.oc_without_padding = jcp.oc; |
1557 | jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; |
1558 | |
1559 | jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; |
1560 | jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2]; |
1561 | jcp.iw = diff_src_d.dims()[ndims - 1]; |
1562 | jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; |
1563 | jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2]; |
1564 | jcp.ow = diff_dst_d.dims()[ndims - 1]; |
1565 | |
1566 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
1567 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1568 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1569 | |
1570 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
1571 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
1572 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1573 | |
1574 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
1575 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
1576 | jcp.stride_w = cd.strides[ndims - 3]; |
1577 | |
1578 | jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; |
1579 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4]; |
1580 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1581 | jcp.dst_dt = cd.diff_src_desc.data_type; |
1582 | jcp.nb_iw = 1; |
1583 | jcp.iw_block = jcp.iw; |
1584 | |
1585 | /* Dilated convolutions supported with unit strides only */ |
1586 | if ((jcp.dilate_w != 0 && jcp.stride_w != 1) |
1587 | || (jcp.dilate_d != 0 && jcp.stride_d != 1) |
1588 | || (jcp.dilate_h != 0 && jcp.stride_h != 1)) |
1589 | return status::unimplemented; |
1590 | |
1591 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1592 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
1593 | int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
1594 | jcp.r_pad = calculate_end_padding( |
1595 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
1596 | jcp.b_pad = calculate_end_padding( |
1597 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
1598 | jcp.back_pad = calculate_end_padding( |
1599 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); |
1600 | bool kernel_outside_src = false || ext_kw <= jcp.l_pad |
1601 | || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad |
1602 | || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; |
1603 | if (kernel_outside_src) return status::unimplemented; |
1604 | |
1605 | jcp.aligned_threads = 0; |
1606 | |
1607 | const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); |
1608 | const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c); |
1609 | const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); |
1610 | const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
1611 | auto curr_src_tag = diff_src_d.matches_one_of_tag( |
1612 | dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); |
1613 | auto curr_dst_tag = diff_dst_d.matches_one_of_tag( |
1614 | dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); |
1615 | bool is_data_layout_nxc |
1616 | = IMPLICATION(curr_src_tag != dat_tag_nxc, |
1617 | diff_src_d.format_kind() == format_kind::any) |
1618 | && IMPLICATION(curr_dst_tag != dat_tag_nxc, |
1619 | diff_dst_d.format_kind() == format_kind::any) |
1620 | && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
1621 | |
1622 | bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc; |
1623 | |
1624 | jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
1625 | |
1626 | const bool ok_to_try_lower_zmm = true |
1627 | && IMPLICATION(is_data_layout_nxc, |
1628 | jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w |
1629 | && jcp.ngroups > 1) |
1630 | && !ok_to_pad_channels |
1631 | && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0); |
1632 | |
1633 | if (ok_to_try_lower_zmm) { |
1634 | for (auto simd : {8, 4}) { |
1635 | if (jcp.ic % simd == 0 && jcp.oc % simd == 0) { |
1636 | jcp.simd_w = simd; |
1637 | break; |
1638 | } |
1639 | } |
1640 | } |
1641 | |
1642 | jcp.oc_block = jcp.simd_w; |
1643 | jcp.ic_block = jcp.simd_w; |
1644 | |
1645 | if (ok_to_pad_channels) { |
1646 | jcp.oc = rnd_up(jcp.oc, jcp.oc_block); |
1647 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
1648 | } |
1649 | |
1650 | if (!IMPLICATION(!is_data_layout_nxc, |
1651 | jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0)) |
1652 | return status::unimplemented; |
1653 | jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0; |
1654 | jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0; |
1655 | |
1656 | format_tag_t wei_tag, dat_tag; |
1657 | |
1658 | if (jcp.simd_w == 8) { |
1659 | dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; |
1660 | wei_tag = utils::pick(ndims - 3, gOIw4o8i2o, gOIhw4o8i2o, gOIdhw4o8i2o); |
1661 | } else if (jcp.simd_w == 4) { |
1662 | dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c; |
1663 | wei_tag = utils::pick(ndims - 3, gOIw2o4i2o, gOIhw2o4i2o, gOIdhw2o4i2o); |
1664 | } else { |
1665 | dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; |
1666 | wei_tag = pick(2 * ndims - 6 + with_groups, OIw8o16i2o, gOIw8o16i2o, |
1667 | OIhw8o16i2o, gOIhw8o16i2o, OIdhw8o16i2o, gOIdhw8o16i2o); |
1668 | } |
1669 | |
1670 | if (diff_src_md.format_kind == format_kind::any) { |
1671 | CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag)); |
1672 | } else if (curr_src_tag != dat_tag) |
1673 | return status::unimplemented; |
1674 | jcp.src_tag = dat_tag; |
1675 | |
1676 | if (diff_dst_md.format_kind == format_kind::any) { |
1677 | CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag)); |
1678 | } else if (curr_dst_tag != dat_tag) |
1679 | return status::unimplemented; |
1680 | jcp.dst_tag = dat_tag; |
1681 | |
1682 | if (weights_md.format_kind == format_kind::any) { |
1683 | CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); |
1684 | jcp.wei_tag = wei_tag; |
1685 | } else { |
1686 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
1687 | if (jcp.wei_tag != wei_tag) return status::unimplemented; |
1688 | } |
1689 | |
1690 | bool args_ok = true && jcp.ic <= diff_src_d.padded_dims()[1] |
1691 | && jcp.oc <= diff_dst_d.padded_dims()[1] |
1692 | && jcp.ic <= weights_d.padded_dims()[with_groups + 1] |
1693 | && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; |
1694 | if (!args_ok) return status::unimplemented; |
1695 | |
1696 | jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block); |
1697 | jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block); |
1698 | |
1699 | jcp.ur_w = jcp.stride_w; |
1700 | |
1701 | /* Maximum number of registers available for result accumulation and delta |
1702 | dst data. One additional register is reserved for weights data. */ |
1703 | const int max_regs |
1704 | = isa_has_bf16(jcp.isa) ? 31 : 26; /* In case of cpx emulation |
1705 | additional 5 registers are |
1706 | reserved */ |
1707 | int l_overflow = nstl::max( |
1708 | 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); |
1709 | |
1710 | jcp.typesize_in = types::data_type_size(diff_dst_d.data_type()); |
1711 | jcp.typesize_out = types::data_type_size(diff_src_d.data_type()); |
1712 | |
1713 | /* Find the best blocking with maximum number of compute instructions |
1714 | per ur_w * nb_ic_blocking compute loops. Number of required registers |
1715 | is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. |
1716 | ur_w must be divisible by stride_w */ |
1717 | if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers |
1718 | distribution exceeds max_regs */ |
1719 | return status::unimplemented; |
1720 | |
1721 | jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; |
1722 | { |
1723 | jcp.kernel_kind = expl_bcast; |
1724 | int best_compute_pipeline_length = 0; |
1725 | const int max_ic_blocks = 4; |
1726 | for (int b = 1; b <= max_ic_blocks; b++) { |
1727 | if (jcp.nb_ic % b != 0) continue; |
1728 | |
1729 | for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs |
1730 | && u < jcp.iw + jcp.stride_w; |
1731 | u += jcp.stride_w) { |
1732 | int ur_w = nstl::min(u, jcp.iw); |
1733 | /* maximum 1 step with l_overflow so far */ |
1734 | if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) |
1735 | continue; |
1736 | int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b; |
1737 | if (pipeline_length > best_compute_pipeline_length |
1738 | || (pipeline_length == best_compute_pipeline_length |
1739 | && jcp.ur_w < ur_w)) { |
1740 | jcp.ur_w = ur_w; |
1741 | jcp.nb_ic_blocking = b; |
1742 | best_compute_pipeline_length = pipeline_length; |
1743 | } |
1744 | } |
1745 | } |
1746 | if (best_compute_pipeline_length == 0) /* can't find |
1747 | appropriate blocking */ |
1748 | return status::unimplemented; |
1749 | } |
1750 | jcp.ur_w_tail = jcp.iw % jcp.ur_w; |
1751 | |
1752 | if (is_iw_threading_available(jcp)) { |
1753 | int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; |
1754 | int work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; |
1755 | float no_iw_block_eff |
1756 | = (float)work_units / rnd_up(work_units, jcp.nthr); |
1757 | |
1758 | // current design of generate() requires iw_block >= 2 * ur_w |
1759 | const int min_iw_block = jcp.ur_w * 2; |
1760 | int iw_threads = jcp.nthr / math::gcd(work_units, jcp.nthr); |
1761 | int iw_block = nstl::max(min_iw_block, |
1762 | rnd_up(jcp.iw, jcp.ur_w * iw_threads) / iw_threads); |
1763 | int nb_iw = div_up(jcp.iw, iw_block); |
1764 | |
1765 | float block_eff = (float)jcp.iw / rnd_up(jcp.iw, iw_block); |
1766 | work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih * nb_iw; |
1767 | float work_eff = (float)work_units / rnd_up(work_units, jcp.nthr); |
1768 | float iw_block_eff = block_eff * work_eff; |
1769 | |
1770 | const int iw_thread_min_size = 16 * 128; |
1771 | const float iw_block_cost = 20.0; |
1772 | float block_overhead = nstl::max(0.0f, 1.0f - iw_block_cost / iw_block); |
1773 | |
1774 | bool iw_thread_useful = no_iw_block_eff < block_overhead * iw_block_eff |
1775 | && jcp.ic_block * jcp.iw > iw_thread_min_size; |
1776 | |
1777 | if (iw_thread_useful) { |
1778 | jcp.iw_block = iw_block; |
1779 | jcp.nb_iw = nb_iw; |
1780 | } |
1781 | } |
1782 | |
1783 | if (l_overflow * jcp.stride_w > jcp.ur_w) return status::unimplemented; |
1784 | int r_overflow_no_tail = nstl::max(0, |
1785 | ((jcp.kw - 1) * (jcp.dilate_w + 1) |
1786 | - nstl::max(0, jcp.r_pad + jcp.ur_w_tail)) |
1787 | / jcp.stride_w); |
1788 | bool tails_not_ok = false |
1789 | /* maximum 1 ur_w block with r_overflow so far */ |
1790 | || r_overflow_no_tail * jcp.stride_w > jcp.ur_w |
1791 | /* ur_w must be a multiple of stride */ |
1792 | || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) |
1793 | /* r_pad must not extend beyond ur_w_tail */ |
1794 | || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0)); |
1795 | if (tails_not_ok) return status::unimplemented; |
1796 | |
1797 | /* adjust the thread decomposition |
1798 | * to improve the perf for small problem size |
1799 | * the threshold L1_cache_size/factor and the factor is empirical |
1800 | * simply set the thread number to 4 now |
1801 | * TODO: Add get_thr_eff function to compute optimal thread*/ |
1802 | size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh |
1803 | * jcp.kw * jcp.kd; |
1804 | size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih |
1805 | * jcp.iw * jcp.id; |
1806 | size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh |
1807 | * jcp.ow * jcp.od; |
1808 | size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size); |
1809 | const unsigned int L1_cache_size = platform::get_per_core_cache_size(1); |
1810 | |
1811 | //The factor for 1d: 1, 2d: 2, 3d: 4; |
1812 | int factor = nstl::max(1, (2 * (ndims - 3))); |
1813 | if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) { |
1814 | jcp.nthr = nstl::min(jcp.nthr, 4); |
1815 | } |
1816 | |
1817 | pick_loop_order(jcp); |
1818 | |
1819 | return status::success; |
1820 | } |
1821 | |
// Hard upper bound on the output-width unroll factor (ur_w) used by the
// backward-weights kernel's blocking logic.
const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
1823 | |
// Emit a loop that rewinds reg_src and reg_kernel by one kernel-depth step
// per iteration, undoing reg_kd_count advances so the pointers return to
// their position before the od step.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        od_step_comeback_pointers() {
    Label kd_comeback_label;
    // kj counts down the number of kd steps to walk back.
    mov(kj, reg_kd_count);
    L(kd_comeback_label);
    {
        // Step src back by one filter-depth-to-src distance...
        sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        // ...and the kernel back by one kh*kw slice.
        sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kd_comeback_label, T_NEAR);
    }
}
// Emit a loop that rewinds reg_src and reg_kernel by one kernel-height step
// per iteration, undoing reg_kh advances so the pointers return to their
// position before the oh step.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        oh_step_comeback_pointers() {
    Label kh_comeback_label;
    // kj counts down the number of kh steps to walk back.
    mov(kj, reg_kh);
    L(kh_comeback_label);
    {
        // Step src back by one filter-height-to-src distance...
        sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
        // ...and the kernel back by one kw row.
        sub(reg_kernel, get_kernel_offset(0, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_comeback_label, T_NEAR);
    }
}
1850 | |
1851 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32:: |
1852 | compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r, |
1853 | int ic_block_step, int src_offset, int kernel_offset, |
1854 | int ddst_offset, bool is_tail) { |
1855 | assert(!is_src_layout_nxc() && !is_ddst_layout_nxc()); |
1856 | int kw = jcp.kw; |
1857 | bool no_src_pad = jcp.is_1stconv && !jcp.transpose_src; |
1858 | const int ddst_zmm_base_idx = 24; |
1859 | const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4; |
1860 | const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs; |
1861 | |
1862 | auto zmm_ker = [=](int i_kw, int i_ic) { |
1863 | return Zmm(i_kw * ic_block_step + i_ic); |
1864 | }; |
1865 | auto zmm_ddst = [=](int i_iw) { |
1866 | // TODO: move reg calc to global member funcs |
1867 | return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs); |
1868 | }; |
1869 | |
1870 | auto ker_addr = [=](int i_kw, int i_ic) { |
1871 | auto local_offset = get_kernel_offset(i_ic, i_kw); |
1872 | return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); |
1873 | }; |
1874 | auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t = 0, |
1875 | bool vnni_bcast = false) { |
1876 | auto local_offset = get_src_offset(i_ic, i_iw); |
1877 | return EVEX_compress_addr( |
1878 | reg_src, local_offset + src_offset + extra_offset, vnni_bcast); |
1879 | }; |
1880 | auto ddst_addr = [=](int i_ur) { |
1881 | auto ow_scale = 2; |
1882 | return EVEX_compress_addr( |
1883 | reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset); |
1884 | }; |
1885 | |
1886 | for (int i_kw = 0; i_kw < kw; i_kw++) |
1887 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1888 | auto zmm = zmm_ker(i_kw, i_ic); |
1889 | vpxord(zmm, zmm, zmm); |
1890 | } |
1891 | assert(ur_w % 2 == 0); |
1892 | auto steps = ur_w / 2; |
1893 | |
1894 | const int str_w = jcp.stride_w; |
1895 | const int underflow_boundary = -1; |
1896 | int i_iw_shift = jcp.tr_ow - ur_w - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0); |
1897 | const int overflow_boundary = jcp.iw - 1 - i_iw_shift; |
1898 | |
1899 | for (int s = 0; s < str_w; s++) { |
1900 | const int kw_start = s; |
1901 | assert(jcp.tr_iw % str_w == 0); |
1902 | const int src_stride_w_shift = jcp.tr_iw / str_w; |
1903 | for (int i_ur = 0; i_ur < steps; i_ur++) { |
1904 | auto zmm = zmm_ddst(i_ur); |
1905 | vmovdqu16(zmm, ddst_addr(i_ur)); |
1906 | |
1907 | for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) { |
1908 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1909 | int i_iw = 2 * i_ur + (i_kw * (jcp.dilate_w + 1)) / str_w |
1910 | + s * src_stride_w_shift; |
1911 | bool underflow = false; |
1912 | bool overflow = false; |
1913 | if (no_src_pad) { |
1914 | i_iw = i_iw - pad_l; |
1915 | underflow = i_iw <= underflow_boundary; |
1916 | overflow = is_tail && i_iw >= overflow_boundary; |
1917 | } |
1918 | |
1919 | auto src = Zmm(zmm_src_reg); |
1920 | auto acc = zmm_ker(i_kw, i_ic); |
1921 | auto ddst = zmm_ddst(i_ur); |
1922 | if (underflow || overflow || !isa_has_bf16(jcp.isa)) { |
1923 | assert(ddst != src); |
1924 | assert(acc != src); |
1925 | } |
1926 | assert(ddst != acc); |
1927 | if (underflow || overflow) { |
1928 | if (underflow && i_iw == underflow_boundary) |
1929 | vpbroadcastw(src | everyother_shift_mask | T_z, |
1930 | src_addr(i_iw + 1, i_ic, 0)); |
1931 | else if (overflow && i_iw == overflow_boundary) |
1932 | vpbroadcastw(src | everyother_mask | T_z, |
1933 | src_addr(i_iw, i_ic, 0)); |
1934 | else |
1935 | continue; |
1936 | |
1937 | if (!isa_has_bf16(jcp.isa)) |
1938 | bf16_emu_->vdpbf16ps(acc, ddst, src); |
1939 | else |
1940 | vdpbf16ps(acc, ddst, src); |
1941 | } else if (!isa_has_bf16(jcp.isa)) { |
1942 | vpbroadcastd(src, src_addr(i_iw, i_ic, 0)); |
1943 | bf16_emu_->vdpbf16ps(acc, ddst, src); |
1944 | } else |
1945 | vdpbf16ps(acc, ddst, src_addr(i_iw, i_ic, 0, true)); |
1946 | } |
1947 | } |
1948 | } |
1949 | for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) { |
1950 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
1951 | auto addr = ker_addr(i_kw, i_ic); |
1952 | auto zmm = zmm_ker(i_kw, i_ic); |
1953 | vaddps(zmm, zmm, addr); |
1954 | vmovups(addr, zmm); |
1955 | } |
1956 | } |
1957 | } |
1958 | } |
1959 | |
1960 | int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_w_reorder_size( |
1961 | int ur_w) const { |
1962 | const int reorder_block = 16; |
1963 | return rnd_up(jcp.stride_w * (ur_w - 1) + jcp.kw, reorder_block); |
1964 | } |
1965 | int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32:: |
1966 | interleave_w_reorder_bytes(int ur_w) { |
1967 | return 2 * jcp.typesize_in * interleave_w_reorder_size(ur_w); |
1968 | } |
1969 | int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_stack_size( |
1970 | int ur_w, int ic_block_step) { |
1971 | return ic_block_step * interleave_w_reorder_bytes(ur_w); |
1972 | } |
1973 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32:: |
1974 | compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r, |
1975 | int ic_block_step, int src_offset, int kernel_offset, |
1976 | int ddst_offset, bool is_tail) { |
1977 | // Only supports nchw format src |
1978 | assert(jcp.is_1stconv && !jcp.transpose_src); |
1979 | int kw = jcp.kw; |
1980 | const int ddst_zmm_base_idx = 24; |
1981 | const int in_zmm_base_idx = 24; |
1982 | const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4; |
1983 | //const int num_in_zmm_regs = 8; |
1984 | const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs; |
1985 | const int reorder_block = 16; |
1986 | const int reorder_size = interleave_w_reorder_size(ur_w); |
1987 | const int reorder_bytes = interleave_w_reorder_bytes(ur_w); |
1988 | const int stack_size = interleave_stack_size(ur_w, ic_block_step); |
1989 | if (stack_size > ic_block_step_stack_size) { |
1990 | // This is a guard. Ideally it is never used, but is included to defend |
1991 | // against overlooked edge cases. |
1992 | assert(stack_size <= ic_block_step_stack_size); |
1993 | sub(rsp, stack_size - ic_block_step_stack_size); |
1994 | } |
1995 | |
1996 | auto zmm_ker = [=](int i_kw, int i_ic) { |
1997 | return Zmm(i_kw * ic_block_step + i_ic); |
1998 | }; |
1999 | auto zmm_ddst = [=](int i_iw) { |
2000 | return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs); |
2001 | }; |
2002 | auto zmm_in = [=](int i_iw, int i_ic, bool stride_reg) { |
2003 | int stride = stride_reg ? 1 : 0; |
2004 | return Zmm(in_zmm_base_idx + 4 * (i_ic % 2) + 2 * (i_iw % 2) + stride); |
2005 | }; |
2006 | |
2007 | auto ker_addr = [=](int i_kw, int i_ic) { |
2008 | auto local_offset = get_kernel_offset(i_ic, i_kw); |
2009 | return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); |
2010 | }; |
2011 | auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t = 0, |
2012 | bool vnni_bcast = false) { |
2013 | int local_offset = i_ic * reorder_bytes + 2 * jcp.typesize_in * i_iw; |
2014 | return EVEX_compress_addr(rsp, local_offset, vnni_bcast); |
2015 | }; |
2016 | auto ddst_addr = [=](int i_ur) { |
2017 | auto ow_scale = 2; |
2018 | return EVEX_compress_addr( |
2019 | reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset); |
2020 | }; |
2021 | auto load_src_to_stack = [=](int i_iw, int i_ic, Opmask mask, |
2022 | bool mask_empty, Opmask stride_mask, |
2023 | bool stride_mask_empty) { |
2024 | auto local_offset = get_src_offset(i_ic, i_iw); |
2025 | int stack_offset |
2026 | = i_ic * reorder_bytes + 2 * jcp.typesize_in * (i_iw + pad_l); |
2027 | |
2028 | auto zmm = zmm_in(i_iw, i_ic, false); |
2029 | auto zmm_stride = zmm_in(i_iw, i_ic, true); |
2030 | auto base_addr |
2031 | = EVEX_compress_addr(reg_src, local_offset + src_offset, false); |
2032 | auto stride_addr = EVEX_compress_addr(reg_src, |
2033 | local_offset + src_offset + get_src_offset(0, jcp.stride_w)); |
2034 | auto stack_addr = EVEX_compress_addr(rsp, stack_offset); |
2035 | assert(IMPLICATION(mask_empty, stride_mask_empty)); |
2036 | if (mask_empty) { |
2037 | vpxord(zmm, zmm, zmm); |
2038 | } else { |
2039 | vpmovzxwd(zmm | mask | T_z, base_addr); |
2040 | } |
2041 | if (!stride_mask_empty) { |
2042 | vpmovzxwd(zmm_stride | stride_mask | T_z, stride_addr); |
2043 | vpslld(zmm_stride, zmm_stride, 16); |
2044 | vpord(zmm, zmm, zmm_stride); |
2045 | } |
2046 | vmovdqu16(stack_addr, zmm); |
2047 | }; |
2048 | |
2049 | assert(ur_w % 2 == 0); |
2050 | auto steps = ur_w / 2; |
2051 | |
2052 | const int str_w = jcp.stride_w; |
2053 | int i_iw_shift = str_w * (jcp.tr_ow - ur_w) |
2054 | - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0); |
2055 | const int overflow_boundary |
2056 | = is_tail ? jcp.iw - i_iw_shift : str_w * (ur_w - 1) + kw - pad_l; |
2057 | |
2058 | // Calculate padding required by the data reorder using 32 byte loads |
2059 | int reorder_overflow = reorder_size - pad_l - overflow_boundary; |
2060 | int reorder_stride_overflow = reorder_overflow + str_w; |
2061 | reorder_overflow = nstl::max(0, reorder_overflow); |
2062 | reorder_stride_overflow = nstl::max(0, reorder_stride_overflow); |
2063 | int reorder_pad_r = reorder_overflow % reorder_block; |
2064 | int reorder_stride_pad_r = reorder_stride_overflow % reorder_block; |
2065 | if (reorder_stride_overflow >= reorder_size && reorder_stride_pad_r == 0) { |
2066 | assert(reorder_stride_overflow == reorder_size); |
2067 | reorder_stride_pad_r = reorder_block; |
2068 | } |
2069 | reorder_overflow -= reorder_pad_r; |
2070 | reorder_stride_overflow -= reorder_stride_pad_r; |
2071 | |
2072 | int pad_l_mask = (0xffff << pad_l) & 0xffff; |
2073 | int pad_l_mask_strided |
2074 | = (0xffff << (pad_l >= str_w ? (pad_l - str_w) : 0)) & 0xffff; |
2075 | int pad_r_mask = 0xffff >> reorder_pad_r; |
2076 | int pad_r_mask_strided = 0xffff >> (reorder_stride_pad_r); |
2077 | pad_r_mask = pad_r_mask & 0xffff; |
2078 | |
2079 | // Setup masks to load and reorder data |
2080 | if (reorder_size - reorder_stride_overflow > reorder_block) { |
2081 | // Overflow and underflow happen in different data reorder rounds |
2082 | kxnorw(overflow_stride_mask, overflow_stride_mask, |
2083 | overflow_stride_mask); |
2084 | kshiftlw(underflow_mask, overflow_stride_mask, pad_l); |
2085 | kshiftlw(underflow_stride_mask, overflow_stride_mask, |
2086 | pad_l >= str_w ? pad_l - str_w : 0); |
2087 | kshiftrw(overflow_mask, overflow_stride_mask, reorder_pad_r); |
2088 | kshiftrw(overflow_stride_mask, overflow_stride_mask, |
2089 | reorder_stride_pad_r); |
2090 | } else if (reorder_size - reorder_overflow > reorder_block) { |
2091 | // Overflow and underflow happen in the same round for loading the data |
2092 | // at the stride offset. |
2093 | kxnorw(overflow_mask, overflow_mask, overflow_mask); |
2094 | kshiftlw(underflow_mask, overflow_mask, pad_l); |
2095 | kshiftrw(overflow_mask, overflow_mask, reorder_pad_r); |
2096 | mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided); |
2097 | kmovw(underflow_stride_mask, reg_tmp.cvt32()); |
2098 | } else { |
2099 | // Overflow and underflow happen in the same round for all data loads |
2100 | mov(reg_tmp.cvt32(), pad_l_mask & pad_r_mask); |
2101 | kmovw(underflow_mask, reg_tmp.cvt32()); |
2102 | mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided); |
2103 | kmovw(underflow_stride_mask, reg_tmp.cvt32()); |
2104 | } |
2105 | |
2106 | // Load and reorder data to the stack |
2107 | int reorder_start = -pad_l; |
2108 | int reorder_end = reorder_size - pad_l; |
2109 | for (int i_iw = reorder_start; i_iw < reorder_end; i_iw += reorder_block) { |
2110 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2111 | Opmask mask, stride_mask; |
2112 | bool mask_empty, stride_mask_empty; |
2113 | // Performing this reorder on the stack may not be (always) optimal. |
2114 | // There are a couple of methods involving externally reordering the |
2115 | // data that were not considered due to time constraints. The first |
2116 | // is to transpose similar to the extern method. The other is to |
2117 | // perform the same interleave transform used here. The tradeoff |
2118 | // between these methods is the transpose method does not lend |
2119 | // itself to SIMD instructions (except possibly for some specific |
2120 | // strides) since the data is not blocked. The transform performed |
2121 | // here does, but uses twice as much data since |
2122 | // most data elements are duplicated. |
2123 | |
2124 | if (i_iw == reorder_start) { |
2125 | mask = underflow_mask; |
2126 | mask_empty = false; |
2127 | if (pad_l_mask == 0) mask_empty = true; |
2128 | } else if (i_iw + reorder_overflow >= reorder_end) { |
2129 | mask_empty = true; |
2130 | } else if (i_iw + reorder_block + reorder_overflow >= reorder_end) { |
2131 | mask = overflow_mask; |
2132 | mask_empty = false; |
2133 | if (pad_r_mask == 0) mask_empty = true; |
2134 | } else { |
2135 | mask = m_ffffffff; |
2136 | mask_empty = false; |
2137 | } |
2138 | if (i_iw == reorder_start) { |
2139 | stride_mask = underflow_stride_mask; |
2140 | stride_mask_empty = false; |
2141 | if (pad_l_mask_strided == 0) mask_empty = true; |
2142 | } else if (i_iw + reorder_stride_overflow >= reorder_end) { |
2143 | stride_mask_empty = true; |
2144 | } else if (i_iw + reorder_block + reorder_stride_overflow |
2145 | >= reorder_end) { |
2146 | stride_mask = overflow_stride_mask; |
2147 | stride_mask_empty = false; |
2148 | if (pad_r_mask_strided == 0) mask_empty = true; |
2149 | } else { |
2150 | stride_mask = m_ffffffff; |
2151 | stride_mask_empty = false; |
2152 | } |
2153 | load_src_to_stack(i_iw, i_ic, mask, mask_empty, stride_mask, |
2154 | stride_mask_empty); |
2155 | } |
2156 | } |
2157 | |
2158 | // Initialize kernel accumulators. It should sometimes be possible to skip |
2159 | // initializing and storing this data between calls to this function. |
2160 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2161 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2162 | auto zmm = zmm_ker(i_kw, i_ic); |
2163 | vpxord(zmm, zmm, zmm); |
2164 | } |
2165 | |
2166 | // Calculate this blocks contribution |
2167 | for (int i_ur = 0; i_ur < steps; i_ur++) { |
2168 | auto zmm = zmm_ddst(i_ur); |
2169 | vmovdqu16(zmm, ddst_addr(i_ur)); |
2170 | |
2171 | for (int i_kw = 0; i_kw < kw; i_kw++) { |
2172 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2173 | int i_iw = 2 * i_ur * str_w + i_kw; |
2174 | auto acc = zmm_ker(i_kw, i_ic); |
2175 | auto ddst = zmm_ddst(i_ur); |
2176 | |
2177 | const bool isa_supports_bf16 = isa_has_bf16(jcp.isa); |
2178 | auto src_stack_addr |
2179 | = src_addr(i_iw, i_ic, 0, isa_supports_bf16); |
2180 | |
2181 | if (isa_supports_bf16) |
2182 | vdpbf16ps(acc, ddst, src_stack_addr); |
2183 | else { |
2184 | auto src = Zmm(zmm_src_reg); |
2185 | vpbroadcastd(src, src_stack_addr); |
2186 | bf16_emu_->vdpbf16ps(acc, ddst, src); |
2187 | } |
2188 | } |
2189 | } |
2190 | } |
2191 | |
2192 | // Store kernel accumulators |
2193 | for (int i_kw = 0; i_kw < kw; i_kw++) { |
2194 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2195 | auto addr = ker_addr(i_kw, i_ic); |
2196 | auto zmm = zmm_ker(i_kw, i_ic); |
2197 | vaddps(zmm, zmm, addr); |
2198 | vmovups(addr, zmm); |
2199 | } |
2200 | } |
2201 | |
2202 | if (stack_size > ic_block_step_stack_size) { |
2203 | // This is a guard. Ideally it is never used, but is included to defend |
2204 | // against overlooked edge cases. |
2205 | add(rsp, stack_size - ic_block_step_stack_size); |
2206 | } |
2207 | } |
2208 | |
// Permute (and zero-pad at borders) src rows into the on-stack staging buffer
// in the word-interleaved order expected by the vpermw-based compute paths.
// One 64-byte cache line is written per (i_ur, i_kw) position.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        convert_src_to_vnni_format(
                int ur_w, int pad_l, int pad_r, int src_offset) {
    Reg64 reg_trans_tmp = r11;
    const int ic_tail = jcp.ic_tail;
    // r11 is in use elsewhere; spill it around the permutation-table load.
    mov(EVEX_compress_addr(rsp, trans_tmp_offset), reg_trans_tmp);

    // Load the word-permutation table used by vpermw below.
    mov(reg_trans_tmp, dst_prm_table);
    vmovups(get_perm_reg(), ptr[reg_trans_tmp]);

    mov(reg_trans_tmp, EVEX_compress_addr(rsp, trans_tmp_offset));
    const int max_regs = 16;
    if (ic_tail) {
        // Narrow the half-masks to the ic tail when the remaining channel
        // count (reg_icb) is less than a full simd width.
        Label skip_tail_mask;
        cmp(reg_icb, jcp.simd_w);
        jge(skip_tail_mask);
        kandd(m_0000ffff, m_0000ffff, m_0000_ic_tail);
        kandd(m_ffff0000, m_ffff0000, m_ic_tail_0000);
        L(skip_tail_mask);
    }
    // One iteration per cache line of the staging buffer.
    for (int src_count = 0;
            sizeof_cacheline * src_count < permw_stack_size(ur_w);
            src_count++) {
        int i_ur = nstl::min(src_count, ur_w - 2);
        int i_kw = src_count - i_ur;
        int buffer_offset = permw_buffer_start + src_count * 64;
        auto bcast_values = Zmm(src_count % max_regs);
        // check == false means this position is entirely padding.
        bool check = check_borders(ur_w, pad_l, pad_r, i_ur, i_kw);
        if (check) {
            if (is_src_layout_nxc()) {
                // nxc: the two halves of the zmm come from two consecutive
                // spatial positions; iw == -1 marks an out-of-bounds one.
                int iw_1, iw_2;
                get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);
                if (iw_1 == -1)
                    vxorpd(bcast_values, bcast_values, bcast_values);
                else {
                    dim_t local_src_offset = src_offset
                            + get_src_offset(
                                    0, filter_w_to_src(i_kw, i_ur, pad_l));
                    vmovdqu16(bcast_values | m_0000ffff | T_z,
                            ptr[reg_src + local_src_offset]);
                }
                if (iw_2 != -1) {
                    // -32 bytes: the high half-mask load targets the upper
                    // 256 bits of the zmm.
                    dim_t local_src_offset = src_offset - 32
                            + get_src_offset(
                                    0, filter_w_to_src(i_kw, i_ur + 1, pad_l));
                    vmovdqu16(bcast_values | m_ffff0000,
                            ptr[reg_src + local_src_offset]);
                }
            } else {
                // Blocked layout: a single masked load covers the position.
                Opmask load_mask;
                get_load_mask(ur_w, pad_l, pad_r, i_ur, i_kw, load_mask);

                dim_t local_src_offset = src_offset
                        + get_src_offset(0, filter_w_to_src(i_kw, i_ur, pad_l));
                vmovdqu16(bcast_values | load_mask | T_z,
                        ptr[reg_src + local_src_offset]);
            }
            // Reorder words into VNNI-friendly order.
            vpermw(bcast_values, get_perm_reg(), bcast_values);
        } else {
            vpxord(bcast_values, bcast_values, bcast_values);
        }
        vmovups(ptr[rsp + buffer_offset], bcast_values);
    }
    if (ic_tail) {
        // Reset-back the masks
        kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
        kshiftld(m_ffff0000, m_0000ffff, 16);
    }
}
2278 | |
// If an oc tail exists, narrow the 16-bit half-masks so diff_dst loads only
// touch valid channels; applied only when the runtime load_work is below a
// full simd width. No-op otherwise.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        may_be_set_oc_tail_mask() {
    if (jcp.oc_tail) {
        Label skip_tail_mask;
        cmp(dword[param + GET_OFF(load_work)], jcp.simd_w);
        jge(skip_tail_mask);
        kandd(m_0000ffff, m_0000ffff, m_0000_oc_tail);
        kandd(m_ffff0000, m_ffff0000, m_oc_tail_0000);
        L(skip_tail_mask);
    }
}
2290 | |
// Restore the full half-masks (low and high 16 bits all set) after a possible
// narrowing by may_be_set_oc_tail_mask(). No-op when there is no oc tail.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        may_be_reset_oc_tail_mask() {
    if (jcp.oc_tail) {
        // Reset-back the masks
        kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
        kshiftld(m_ffff0000, m_0000ffff, 16);
    }
}
2299 | |
// Accumulate the weight-diff contribution of one ic_block_step slice using
// src data that was pre-staged on the stack by convert_src_to_vnni_format,
// with an explicit software pipeline: src broadcasts and diff_dst loads for
// upcoming iterations are issued ahead of the vdpbf16ps that consume them.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        compute_ic_block_step_vpermw_expl(int ur_w, int pad_l, int pad_r,
                int ic_block_step, int src_offset, int kernel_offset,
                int ddst_offset, bool is_tail) {
    assert(!jcp.is_1stconv); // This method does not support nchw data
    int kw = jcp.kw;
    int src_count = 0;
    // Which ic_block_step chunk of the staged buffer this call addresses.
    int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
    // bf16 emulation reserves 5 extra zmms (26 vs 31 usable).
    const int max_regs = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
    int src_pl_len = kw;
    // Register file split: [0, kw*ic_block_step) accumulators,
    // then src pipeline, then diff_dst pipeline.
    const int diff_dst_pl_start_reg_idx = ic_block_step * (kw + src_pl_len);
    const int diff_dst_pl_len = max_regs - diff_dst_pl_start_reg_idx;

    auto get_diff_wei_reg_idx
            = [=](int i_kw, int i_ic) { return i_kw * ic_block_step + i_ic; };
    auto get_src_reg_idx = [=](int i_iw, int i_ic) {
        return kw * ic_block_step + (i_iw % src_pl_len) * ic_block_step + i_ic;
    };
    auto get_diff_dst_reg_idx = [=](int i_ur) {
        return diff_dst_pl_start_reg_idx + (i_ur / 2) % diff_dst_pl_len;
    };

    may_be_set_oc_tail_mask();
    // Load (and vpermw-reorder) one pair of diff_dst rows into the pipeline.
    auto load_dst = [=](int c) {
        // NOTE: shadows the function parameter 'is_tail' deliberately; this
        // is the per-pair odd-ur_w tail condition.
        bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
        bool is_ddst_nxc = is_ddst_layout_nxc();
        auto offset = get_ddst_offset(c * 2) + ddst_offset;

        Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
        vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | load_mask | T_z,
                EVEX_compress_addr(reg_ddst, offset));

        if (is_ddst_nxc && !is_tail) {
            // High half comes from the next spatial position (-32 bytes to
            // target the upper 256 bits via the high half-mask).
            offset += get_ddst_offset(1) - 32;
            vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | m_ffff0000,
                    EVEX_compress_addr(reg_ddst, offset));
        }
        vpermw(Zmm(get_diff_dst_reg_idx(2 * c)), get_perm_reg(),
                Zmm(get_diff_dst_reg_idx(2 * c)));
    };

    // Zero-initialize all weight-diff accumulators.
    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
            vpxord(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
                    Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
                    Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));

    // Address into the staged stack buffer for a src broadcast; b_ic indexes
    // the current ic position at runtime.
    auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
        int scale = 2 * jcp.typesize_in;
        return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
                + jcp.typesize_in * 2
                * (ic_block_step_idx * ic_block_step + ic);
    };
    int src_count_last = 0;
    for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
        if (i_ur == 0) {
            // Prologue: fill the diff_dst and src pipelines.
            for (int dst_count = 0;
                    dst_count < nstl::min(diff_dst_pl_len, div_up(ur_w, 2));
                    dst_count++) {
                load_dst(dst_count);
            }
            for (src_count = 0; src_count < src_pl_len; src_count++) {
                int _i_ur = src_count / kw;
                int _i_kw = src_count % kw;
                if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
                    for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                        vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
                                ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
                    }
            }
            src_count_last = src_count;
        } else {
            // Steady state: refill the slots freed by the previous step.
            int diff_dst_load_idx = i_ur + 2 * (diff_dst_pl_len - 1);
            if (diff_dst_load_idx < ur_w) load_dst(diff_dst_load_idx / 2);
            for (src_count = i_ur; src_count < i_ur + src_pl_len; src_count++) {
                if (src_count < src_count_last) continue;
                int _i_ur = (src_count - i_ur) / kw + i_ur;
                int _i_kw = (src_count - i_ur) % kw;
                if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
                    for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                        vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
                                ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
                    }
            }
            src_count_last = src_count;
        }
        // Consume: one vnni accumulation per valid (i_kw, i_ic) tap.
        for (int i_kw = 0; i_kw < kw; i_kw++) {
            int i_iw = i_ur + i_kw;
            if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
                for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                    if (!isa_has_bf16(jcp.isa)) {
                        bf16_emu_->vdpbf16ps(
                                Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
                                Zmm(get_diff_dst_reg_idx(i_ur)),
                                Zmm(get_src_reg_idx(i_iw, i_ic)));
                    } else {
                        vdpbf16ps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
                                Zmm(get_diff_dst_reg_idx(i_ur)),
                                Zmm(get_src_reg_idx(i_iw, i_ic)));
                    }
                }
            }
        }
    }

    // Accumulate into, then store back to, the kernel buffer.
    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            auto l_offset = get_kernel_offset(i_ic, i_kw);
            vaddps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
                    EVEX_compress_addr(reg_kernel, l_offset + kernel_offset));
        }

    for (int i_kw = 0; i_kw < kw; i_kw++) {
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            auto l_offset = get_kernel_offset(i_ic, i_kw);
            vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
                    Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
        }
    }

    may_be_reset_oc_tail_mask();
}
2422 | |
// Emits the diff_weights accumulation for one chunk of `ic_block_step` input
// channels over a register block of `ur_w` output-width points, using the
// vpermw-based path: diff_dst pairs are permuted into VNNI order on the fly,
// while src is read (broadcast) from the on-stack buffer prepared by
// convert_src_to_vnni_format(). Accumulators live in Zmm(i_kw *
// ic_block_step + i_ic); they are preloaded from and stored back to the
// kernel buffer at reg_kernel + kernel_offset.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
                int ic_block_step, int src_offset, int kernel_offset,
                int ddst_offset, bool is_tail) {
    assert(!jcp.is_1stconv); // This method does not support nchw data
    int kw = jcp.kw;

    int dst_count = 0;

    // Index of the current ic chunk inside the input-channel block; used to
    // address the corresponding slice of the on-stack permuted src buffer.
    int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);

    // Number of diff_dst registers kept in flight ahead of their use; no
    // pipelining on the bf16-emulation path.
    int pipeline_length = (isa_has_bf16(jcp.isa))
            ? nstl::max(1, nstl::min(4, ur_w / 2))
            : 1;
    may_be_set_oc_tail_mask();

    // Top zmm index of the diff_dst pipeline; the emulator path reserves the
    // upper registers for bf16_emu_ scratch, hence the lower limit of 26.
    const int dst_off_reg = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
    // Loads the pair of ow points (2*c, 2*c + 1) and permutes them into VNNI
    // (interleaved) order for vdpbf16ps.
    auto load_dst = [=](int c) {
        // Odd ur_w: the last pair holds a single valid ow point.
        bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
        bool is_ddst_nxc = is_ddst_layout_nxc();
        auto offset = get_ddst_offset(2 * c) + ddst_offset;

        // nxc layout (or a tail) provides only 16 valid bf16 words per load.
        Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
        vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | load_mask | T_z,
                EVEX_compress_addr(reg_ddst, offset));

        if (is_ddst_nxc && !is_tail) {
            // Second ow point of the pair fills the upper 16 words; -32
            // rewinds the 16 bf16 elements (32 bytes) already covered.
            offset += get_ddst_offset(1) - 32;
            vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | m_ffff0000,
                    EVEX_compress_addr(reg_ddst, offset));
        }
        vpermw(Zmm(dst_off_reg - c % pipeline_length), get_perm_reg(),
                Zmm(dst_off_reg - c % pipeline_length));
    };

    // Preload the diff_weights accumulators from the kernel buffer.
    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
            vmovups(Zmm(i_kw * ic_block_step + i_ic),
                    EVEX_compress_addr(reg_kernel,
                            get_kernel_offset(i_ic, i_kw) + kernel_offset));

    // Prime the diff_dst pipeline.
    for (dst_count = 0; dst_count < pipeline_length; dst_count++) {
        load_dst(dst_count);
    }
    // Address of an (ow-pair, kw, ic) element inside the on-stack
    // VNNI-permuted src buffer (64 bytes per iw position).
    auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
        int scale = 2 * jcp.typesize_in;
        return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
                + jcp.typesize_in * 2
                * (ic_block_step_idx * ic_block_step + ic);
    };

    for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
        for (int i_kw = 0; i_kw < kw; i_kw++) {
            if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
                for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                    if (!isa_has_bf16(jcp.isa)) {
                        // Emulation path: broadcast src explicitly since the
                        // emulator takes no embedded-broadcast operand.
                        auto zmm_src = Zmm(28);
                        vpbroadcastd(
                                zmm_src, ptr[get_bcast_ptr(i_ur, i_kw, i_ic)]);
                        bf16_emu_->vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
                                Zmm(dst_off_reg - dst_count % pipeline_length),
                                zmm_src);
                    } else {
                        vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
                                Zmm(dst_off_reg - dst_count % pipeline_length),
                                zword_b[get_bcast_ptr(i_ur, i_kw, i_ic)]);
                    }
                }
            }
        }
        // Refill the pipeline with the next diff_dst pair, if any remain.
        if (dst_count * 2 < ur_w) load_dst(dst_count);
        dst_count++;
    }
    // Store the updated accumulators back to the kernel buffer.
    for (int i_kw = 0; i_kw < kw; i_kw++) {
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            auto l_offset = get_kernel_offset(i_ic, i_kw);
            vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
                    Zmm(i_kw * ic_block_step + i_ic));
        }
    }

    may_be_reset_oc_tail_mask();
}
2506 | |
2507 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32:: |
2508 | compute_diff_bias_init() { |
2509 | auto reg_unit_val = reg_tmp.cvt16(); |
2510 | mov(reg_unit_val, 0x3f80); // bf16 value of 1. |
2511 | vpbroadcastw(vreg_bias_unit, reg_unit_val); |
2512 | |
2513 | mov(reg_tmp, ptr[param + GET_OFF(bias)]); |
2514 | vmovups(vreg_bias_acc, ptr[reg_tmp]); |
2515 | |
2516 | if (jcp.uses_permw_transposition) { |
2517 | mov(reg_tmp, dst_prm_table); |
2518 | vmovups(get_perm_reg(), ptr[reg_tmp]); |
2519 | } |
2520 | } |
2521 | |
// Accumulates diff_bias over one ow row of diff_dst into vreg_bias_acc by
// multiplying pairs of bf16 diff_dst values with a vector of bf16 ones via
// vdpbf16ps. Runs only when the FLAG_IC_FIRST flag is set (only the first ic
// chunk contributes to bias). With is_partial the accumulator is reloaded
// from and written back to params->bias around this row. reg_ddst is
// restored to its entry value before returning.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_diff_bias_row(
        bool is_partial) {
    if (!jcp.with_bias) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);
    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    may_be_set_oc_tail_mask();

    if (is_partial) compute_diff_bias_init();

    // Processes one pair of ow points (or a single one when is_tail).
    auto compute_step = [&](bool is_tail) {
        if (jcp.transpose_dst) {
            // Buffer is already in VNNI layout: plain full load.
            UNUSED(is_tail);
            vmovups(vreg_bias_ddst, ptr[reg_ddst]);
        } else {
            // Load 16 bf16 values into the lower half (zeroing the rest for
            // nxc/tail), optionally fill the upper half with the next ow
            // point, then interleave into VNNI order.
            auto vreg_ddst_load = is_ddst_layout_nxc() || is_tail
                    ? vreg_bias_ddst | m_0000ffff | T_z
                    : vreg_bias_ddst;
            vmovdqu16(vreg_ddst_load, ptr[reg_ddst]);
            if (is_ddst_layout_nxc() && !is_tail) {
                const int shift_16_elems = 16 * jcp.typesize_in;
                vmovdqu16(vreg_bias_ddst | m_ffff0000,
                        ptr[reg_ddst + get_ddst_offset(1) - shift_16_elems]);
            }
            vpermw(vreg_bias_ddst, get_perm_reg(), vreg_bias_ddst);
        }
        // acc += ddst * 1.0 (pairwise bf16 dot-product accumulate).
        if (!isa_has_bf16(jcp.isa))
            bf16_emu_->vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
        else
            vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
    };

    Label ow_loop, ow_tail;
    int niters = jcp.tr_ow / 2;
    if (niters > 0) {
        mov(reg_tmp, jcp.tr_ow / 2);
        L(ow_loop);
        compute_step(false);
        add(reg_ddst, get_ddst_offset(2));
        sub(reg_tmp, 1);
        jnz(ow_loop, T_NEAR);
    }
    // Odd tr_ow leaves one trailing ow point.
    if (jcp.tr_ow % 2) compute_step(true);

    // Rewind reg_ddst to the start of the row.
    if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));

    if (is_partial) {
        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp], vreg_bias_acc);
    }

    may_be_reset_oc_tail_mask();

    L(skip_label);
}
// Emits the full diff_bias reduction over all output rows handled by this
// kernel invocation, iterating compute_diff_bias_row() over the oh range
// determined by the harness type. Skipped entirely for harness_3d_reduction
// (see note below) and when FLAG_IC_FIRST is not set at runtime.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        maybe_compute_diff_bias() {
    // In harness_3d_reduction case calculation of diff_bias is called
    // for every ow row separately to be aligned with od loop in
    // compute_od_loop_common()
    if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);

    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    // reg_oj <- number of output rows to reduce over.
    switch (jcp.harness) {
        case harness_2d_reduction:
            mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
            break;
        case harness_mb_reduction:
        case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
        case harness_3d_reduction:
        default: assert(!"Invalid harness type" );
    }

    compute_diff_bias_init();

    cmp(reg_oj, 0);
    jle(skip_label, T_NEAR); // nothing to do
    Label bias_loop;
    L(bias_loop);
    {
        // Row accumulation leaves reg_ddst unchanged; advance it here.
        compute_diff_bias_row(false);
        add(reg_ddst, get_ddst_offset(0, 1));

        sub(reg_oj, 1);
        jnz(bias_loop, T_NEAR);
    }

    // Flush the accumulated bias back to params->bias.
    mov(reg_tmp, ptr[param + GET_OFF(bias)]);
    vmovups(ptr[reg_tmp], vreg_bias_acc);

    // restore reg_ddst value
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);

    L(skip_label);
}
2624 | |
2625 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step( |
2626 | int ur_w, int pad_l, int pad_r, int ic_block_step, int src_offset, |
2627 | int kernel_offset, int ddst_offset, bool is_tail) { |
2628 | |
2629 | if (jcp.uses_permw_transposition) |
2630 | if (jcp.kernel_kind == expl_bcast) |
2631 | compute_ic_block_step_vpermw_expl(ur_w, pad_l, pad_r, ic_block_step, |
2632 | src_offset, kernel_offset, ddst_offset, is_tail); |
2633 | else |
2634 | compute_ic_block_step_vpermw(ur_w, pad_l, pad_r, ic_block_step, |
2635 | src_offset, kernel_offset, ddst_offset, is_tail); |
2636 | else if (jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1) |
2637 | compute_ic_block_step_interleave(ur_w, pad_l, pad_r, ic_block_step, |
2638 | src_offset, kernel_offset, ddst_offset, is_tail); |
2639 | else |
2640 | compute_ic_block_step_extern(ur_w, pad_l, pad_r, ic_block_step, |
2641 | src_offset, kernel_offset, ddst_offset, is_tail); |
2642 | } |
2643 | |
2644 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::get_ur_w( |
2645 | int &ur_w, int &ur_w_tail, int &ur_w_trips) { |
2646 | if (jcp.tr_ow <= max_ur_w) { |
2647 | ur_w = jcp.tr_ow; |
2648 | ur_w_tail = 0; |
2649 | ur_w_trips = 1; |
2650 | return; |
2651 | } |
2652 | |
2653 | int r_pad = 0; |
2654 | if (!jcp.transpose_src) { |
2655 | // If jcp.transpose_src, the buffer has physical padding |
2656 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
2657 | r_pad = nstl::max(0, |
2658 | calculate_end_padding( |
2659 | jcp.l_pad, jcp.tr_ow, jcp.tr_iw, jcp.stride_w, ext_kw)); |
2660 | } |
2661 | int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad; |
2662 | ur_w = max_ur_w; |
2663 | ur_w_trips = jcp.tr_ow / ur_w; |
2664 | ur_w_tail = jcp.tr_ow % ur_w; |
2665 | if ((ur_w_tail == 0 && jcp.r_pad != 0) || r_pad >= ur_w_tail) { |
2666 | if (ur_w_trips > 1) { |
2667 | ur_w_tail += ur_w; |
2668 | ur_w_trips--; |
2669 | } else { |
2670 | int ur_w_tail_total = ur_w + ur_w_tail; |
2671 | ur_w = (ur_w_tail_total % 4 == 0) ? ur_w_tail / 2 |
2672 | : ur_w_tail / 2 + 1; |
2673 | ur_w_tail = ur_w_tail_total - ur_w; |
2674 | if (l_pad > ur_w / 2) { |
2675 | ur_w = (l_pad % 2 == 0) ? l_pad : l_pad + 1; |
2676 | ur_w_tail = ur_w_tail_total - ur_w; |
2677 | } else if (r_pad > ur_w_tail) { |
2678 | ur_w_tail = (r_pad % 2 == 0) ? r_pad : r_pad + 1; |
2679 | ur_w = ur_w_tail_total - ur_w_tail; |
2680 | } |
2681 | } |
2682 | } |
2683 | } |
2684 | |
// Emits the kh (and, for 5D, kd) loop for the fully unrolled case: the whole
// output row (tr_ow <= max_ur_w) and every ic chunk of the ic block are
// handled by direct compute_ic_block_step calls, optionally wrapped in an
// icb (nxc multi-block) / ic-tail loop driven by reduce_work.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
        compute_oh_step_unroll_ow_icblock(int ic_block_step) {
    Label kh_label, kd_label;

    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;
    int ow = jcp.tr_ow;
    int r_pad = 0;
    int ur_w, ur_w_tail, ur_w_trips;
    get_ur_w(ur_w, ur_w_tail, ur_w_trips);
    // This routine is only selected when the row fits into one block.
    assert(ur_w_tail == 0 && ur_w_trips == 1);

    if (!jcp.transpose_src) {
        // If jcp.transpose_src, the buffer has physical padding
        int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
        int iw = jcp.tr_iw;
        r_pad = nstl::max(0,
                calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
    }
    int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;

    if (jcp.ndims == 5) {
        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    L(kh_label);
    {
        const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
        // icb loop is supported for nxc layout only
        assert(IMPLICATION(generate_icb_loop,
                is_src_layout_nxc() && is_ddst_layout_nxc()));
        Label icb_block_label, icb_block_label_end;
        if (generate_icb_loop || ic_tail) {
            // Save pointers and load the remaining-reduce-work counter.
            mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
            mov(ptr[rsp + icb_loop_src_ptr], reg_src);
            mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            L(icb_block_label);
        }

        if (jcp.uses_permw_transposition) {
            // Re-pack src for this row into the on-stack VNNI buffer.
            convert_src_to_vnni_format(ur_w, l_pad, r_pad, 0);
            xor_(b_ic, b_ic);
        }

        const int ic_tail_loop_work = rnd_up(ic_tail, ic_block_step);
        for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
            const int src_offset = get_src_offset(i_b_ic, 0);
            compute_ic_block_step(ur_w, l_pad, r_pad, ic_block_step, src_offset,
                    get_kernel_offset(i_b_ic, 0), 0, true);
            if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
            // We relax the boundary for reg_icb, as the src is already
            // converted to vnni_format with appropriate padding either through
            // transpose_src or convert_to_src_to_vnni_format. We can safely
            // allow compute_ic_block_step overstep as it operates on buffer
            // instead of src.
            if (ic_tail && i_b_ic + ic_block_step == ic_tail_loop_work) {
                assert(jcp.transpose_src || jcp.uses_permw_transposition);
                cmp(reg_icb, 0);
                jle(icb_block_label_end, T_NEAR);
            }
        }
        L(icb_block_label_end);

        const auto src_icb_loop_shift_bytes = get_src_offset(ic_block, 0);
        const auto kernel_icb_loop_shift_bytes
                = get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw);
        if (generate_icb_loop) {
            // Advance to the next ic block and loop while work remains.
            add(reg_src, src_icb_loop_shift_bytes);
            safe_add(reg_kernel, kernel_icb_loop_shift_bytes, reg_long_offt);

            assert(jcp.uses_permw_transposition);
            cmp(reg_icb, 0);
            jg(icb_block_label, T_NEAR);
        }

        if (generate_icb_loop || ic_tail) {
            // restore pointers
            mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
            mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
        }

        // Step src to the next kh input row and kernel to the next kh slice.
        add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
        add(reg_kernel, get_kernel_offset(0, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        // Outer kd loop: advance the depth pointers and iterate.
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
}
2784 | |
// Emits the kh (and, for 5D, kd) loop for the case where the output row
// still fits into one register block (tr_ow <= max_ur_w) but the ic block is
// iterated with a runtime loop (label ic_block_label) instead of being
// unrolled, with optional icb / ic-tail handling driven by reduce_work.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
        compute_oh_step_unroll_ow(int ic_block_step) {
    Label kh_label, ic_block_label, kd_label;

    int ic_block = jcp.ic_block;
    const int ic_tail = jcp.ic_tail;
    int ow = jcp.tr_ow;

    int r_pad = 0;
    int ur_w, ur_w_tail, ur_w_trips;
    get_ur_w(ur_w, ur_w_tail, ur_w_trips);
    // This routine is only selected when the row fits into one block.
    assert(ur_w_tail == 0 && ur_w_trips == 1);

    if (!jcp.transpose_src) {
        // If jcp.transpose_src, the buffer has physical padding
        int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
        int iw = jcp.tr_iw;
        r_pad = nstl::max(0,
                calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
    }
    int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;

    if (jcp.ndims == 5) {
        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    L(kh_label);
    {
        // Per-iteration src advance of the runtime ic loop below.
        size_t src_offset = get_src_offset(ic_block_step, 0);

        const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
        // icb loop is supported for nxc layout only
        assert(IMPLICATION(generate_icb_loop,
                is_src_layout_nxc() && is_ddst_layout_nxc()));
        Label icb_block_label, icb_block_label_end;
        if (generate_icb_loop || ic_tail) {
            // Save pointers and load the remaining-reduce-work counter.
            mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
            mov(ptr[rsp + icb_loop_src_ptr], reg_src);
            mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            L(icb_block_label);
        }

        xor_(b_ic, b_ic);
        if (jcp.uses_permw_transposition) {
            // Re-pack src for this row into the on-stack VNNI buffer.
            convert_src_to_vnni_format(ow, l_pad, r_pad, 0);
            xor_(b_ic, b_ic);
        }

        L(ic_block_label);
        {
            compute_ic_block_step(
                    ur_w, l_pad, r_pad, ic_block_step, 0, 0, 0, true);
            assert(jcp.ic_block % jcp.ic_block_step == 0);
            safe_add(reg_src, src_offset, reg_long_offt);
            add(reg_kernel, get_kernel_offset(ic_block_step, 0));
            add(b_ic, ic_block_step);
            if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
            // We relax the boundary for reg_icb, as the src is already
            // converted to vnni_format with appropriate padding either through
            // transpose_src or convert_to_src_to_vnni_format. We can safely
            // allow compute_ic_block_step overstep as it operates on buffer
            // instead of src.
            if (ic_tail) {
                assert(jcp.transpose_src || jcp.uses_permw_transposition);
                cmp(reg_icb, 0);
                jle(icb_block_label_end, T_NEAR);
            }
            cmp(b_ic, jcp.ic_block);
            jl(ic_block_label, T_NEAR);
        }
        L(icb_block_label_end);

        // Pointer fixup for the next kh iteration; the exact compensation
        // depends on which path advanced reg_src/reg_kernel above.
        if (jcp.uses_permw_transposition) {
            if (generate_icb_loop || ic_tail) {
                // subtract pointer shift made within ic block loop
                // and move to next ic block
                safe_add(reg_kernel,
                        get_kernel_offset(-ic_block, jcp.kd * jcp.kh * jcp.kw),
                        reg_long_offt);

                cmp(reg_icb, 0);
                jg(icb_block_label, T_NEAR);
                // restore pointers
                mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
                mov(reg_src, ptr[rsp + icb_loop_src_ptr]);

                add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
                add(reg_kernel, get_kernel_offset(0, jcp.kw));
            } else {
                add(reg_src,
                        get_src_offset(0, 0, filter_h_to_src(1))
                                - jcp.typesize_in * ic_block);
            }
        } else if (ic_tail) {
            // restore pointers
            mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
            mov(reg_src, ptr[rsp + icb_loop_src_ptr]);

            add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
            add(reg_kernel, get_kernel_offset(0, jcp.kw));
        } else if (jcp.is_1stconv && !jcp.transpose_src) {
            // Fixup reg_src to point to the correct location
            safe_add(reg_src,
                    get_src_offset(0, 0, filter_h_to_src(1))
                            - src_offset * (jcp.ic_block / ic_block_step),
                    reg_long_offt);
        } else {
            if (jcp.dilate_h > 0)
                add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
        }
        if (!generate_icb_loop && !ic_tail)
            // subtract pointer shift made within ic block loop
            // and move to next kh index
            add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_label, T_NEAR);
    }
    if (jcp.ndims == 5) {
        // Outer kd loop: advance the depth pointers and iterate.
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
}
2914 | |
// Emits the general kh (and, for 5D, kd) loop for rows wider than max_ur_w:
// the row is processed as [optional l_pad block] + runtime ow-block loop +
// [optional r_pad tail] per get_ur_w()'s split. Two loop nests are
// generated: kh -> ic -> ow when src is pre-transposed, or
// kh -> [icb] -> ow -> ic when the permw path repacks src per ow block.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_common(
        int ic_block_step) {
    Label kh_label, ic_block_label, ow_block_label, kd_label;

    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;
    int ow = jcp.tr_ow;
    int r_pad = 0;
    if (!jcp.transpose_src) {
        // If jcp.transpose_src, the buffer has physical padding
        int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
        int iw = jcp.tr_iw;
        r_pad = nstl::max(0,
                calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
    }
    int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;

    int ur_w, ur_w_trips, ur_w_tail;
    get_ur_w(ur_w, ur_w_tail, ur_w_trips);
    // get_ur_w() guarantees the padding fits the edge blocks.
    assert(l_pad <= ur_w);
    assert(r_pad <= ur_w_tail);

    // Byte distances to rewind src/ddst to the row start after a full pass.
    auto src_comeback
            = get_src_offset(0, filter_w_to_src(0, ur_w_trips * ur_w, l_pad));
    auto ddst_comeback = get_ddst_offset(ur_w_trips * ur_w);

    if (jcp.ndims == 5) {
        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    bool use_kh_ic_ow_loop_order = !jcp.uses_permw_transposition;
    if (use_kh_ic_ow_loop_order) {
        assert(!jcp.uses_permw_transposition);

        // Walks all ow blocks of the row for one ic chunk, then advances
        // src/kernel to the next ic chunk.
        auto ic_loop = [=](int ic_block_step) {
            Label ow_block_label;
            // create a local copy
            int ur_w_blocks = ur_w_trips;
            auto src_offset = get_src_offset(ic_block_step, 0);
            if (l_pad != 0) {
                // First block absorbs the left padding.
                ur_w_blocks--;
                compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
                add(reg_src,
                        get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
                add(reg_ddst, get_ddst_offset(ur_w));
            }

            if (ur_w_blocks > 0) {
                // Middle blocks: no padding on either side.
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                L(ow_block_label);
                {
                    compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
                    add(reg_src,
                            get_src_offset(0, filter_w_to_src(0, ur_w, 0)));
                    add(reg_ddst, get_ddst_offset(ur_w));

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_blocks);
                    jl(ow_block_label, T_NEAR);
                }
            }

            if (ur_w_tail > 0) {
                // Tail block absorbs the right padding.
                compute_ic_block_step(
                        ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
            }

            // Rewind to the row start, then step to the next ic chunk.
            sub(reg_src, src_comeback);
            sub(reg_ddst, ddst_comeback);

            safe_add(reg_src, src_offset, reg_long_offt);
            add(reg_kernel, get_kernel_offset(ic_block_step, 0));
        };

        mov(kj, reg_kh);
        L(kh_label);
        {
            Label ic_tail_label, skip_ic_tail_offset_compensation;
            if (ic_tail) {
                // It appears currently, generate_icb_loop is not enabled here,
                // implying at most one icb is processed.
                assert(jcp.nb_ic_blocking_max == 1);
                mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            } else {
                mov(reg_icb, ic_block);
            }

            L(ic_block_label);
            {
                ic_loop(ic_block_step);
                sub(reg_icb, ic_block_step);
                // We relax the boundary for reg_icb, as the src is already
                // converted to vnni_format with appropriate padding either
                // through transpose_src or convert_to_src_to_vnni_format. We
                // can safely allow compute_ic_block_step overstep as it
                // operates on buffer instead of src.
                if (ic_tail) {
                    assert(jcp.transpose_src || jcp.uses_permw_transposition);
                }
                cmp(reg_icb, 0);
                jg(ic_block_label, T_NEAR);
            }

            if (ic_tail) {
                // When the tail was rounded up, compensate the extra
                // kernel/src advances made inside the ic loop.
                mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
                cmp(reg_icb, jcp.simd_w);
                je(skip_ic_tail_offset_compensation);
                add(reg_kernel,
                        get_kernel_offset(
                                jcp.ic_block - rnd_up(ic_tail, ic_block_step),
                                0));
                safe_add(reg_src,
                        get_src_offset(0, 0, filter_h_to_src(1))
                                - get_src_offset(
                                        rnd_up(ic_tail, ic_block_step), 0),
                        reg_long_offt);
                L(skip_ic_tail_offset_compensation);
            }
            if (jcp.is_1stconv && !jcp.transpose_src) {
                // Fixup reg_src to point to the correct location
                auto src_offset = get_src_offset(ic_block_step, 0);
                safe_add(reg_src,
                        get_src_offset(0, 0, filter_h_to_src(1))
                                - src_offset * (jcp.ic_block / ic_block_step),
                        reg_long_offt);
            } else if (jcp.dilate_h > 0) {
                add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
            }
            // subtract pointer shift made within ic block loop
            // and move to next kh index
            add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
            dec(kj);
            cmp(kj, 0);
            jg(kh_label, T_NEAR);
        }
    } else {
        // permw path: for each ow block, repack src once and then loop over
        // ic chunks (ow-outer / ic-inner order).
        assert(!jcp.is_1stconv);
        auto src_icbstep_shift = get_src_offset(1, 0);

        auto ic_loop = [=](int ic_block_step) {
            int ic_work = ic_block;
            Label ow_block_label, ic_block_label_padl, ic_block_label_general,
                    ic_block_label_tail;
            int ur_w_blocks = ur_w_trips;
            if (l_pad != 0) {
                // First block absorbs the left padding.
                ur_w_blocks--;
                xor_(b_ic, b_ic);
                if (jcp.uses_permw_transposition) {
                    convert_src_to_vnni_format(ur_w, l_pad, 0, 0);
                }
                L(ic_block_label_padl);
                {
                    compute_ic_block_step(
                            ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
                    safe_add(reg_src, src_icbstep_shift * ic_block_step,
                            reg_long_offt);
                    add(reg_kernel, get_kernel_offset(ic_block_step, 0));

                    add(b_ic, ic_block_step);
                    cmp(b_ic, ic_work);
                    jl(ic_block_label_padl, T_NEAR);
                }
                // Undo the per-ic advances, then step to the next ow block.
                safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
                sub(reg_kernel, get_kernel_offset(ic_work, 0));
                add(reg_src,
                        get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
                add(reg_ddst, get_ddst_offset(ur_w));
            }

            if (ur_w_blocks > 0) {
                // Middle blocks: no padding on either side.
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                L(ow_block_label);
                {
                    if (jcp.uses_permw_transposition) {
                        convert_src_to_vnni_format(ur_w, 0, 0, 0);
                    }
                    xor_(b_ic, b_ic);
                    L(ic_block_label_general);
                    {
                        compute_ic_block_step(
                                ur_w, 0, 0, ic_block_step, 0, 0, 0);
                        safe_add(reg_src, src_icbstep_shift * ic_block_step,
                                reg_long_offt);
                        add(reg_kernel, get_kernel_offset(ic_block_step, 0));

                        add(b_ic, ic_block_step);
                        cmp(b_ic, ic_work);
                        jl(ic_block_label_general, T_NEAR);
                    }
                    safe_sub(reg_src, src_icbstep_shift * ic_work,
                            reg_long_offt);
                    sub(reg_kernel, get_kernel_offset(ic_work, 0));
                    add(reg_src, get_src_offset(0, filter_w_to_src(0, ur_w)));
                    add(reg_ddst, get_ddst_offset(ur_w));

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_blocks);
                    jl(ow_block_label, T_NEAR);
                }
            }

            if (ur_w_tail > 0) {
                // Tail block absorbs the right padding.
                if (jcp.uses_permw_transposition) {
                    convert_src_to_vnni_format(ur_w_tail, 0, r_pad, 0);
                }
                xor_(b_ic, b_ic);
                L(ic_block_label_tail);
                {
                    compute_ic_block_step(
                            ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
                    safe_add(reg_src, src_icbstep_shift * ic_block_step,
                            reg_long_offt);
                    add(reg_kernel, get_kernel_offset(ic_block_step, 0));

                    add(b_ic, ic_block_step);
                    cmp(b_ic, ic_work);
                    jl(ic_block_label_tail, T_NEAR);
                }
                safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
                sub(reg_kernel, get_kernel_offset(ic_work, 0));
            }

            // Rewind src/ddst to the row start.
            sub(reg_src, src_comeback);
            sub(reg_ddst, ddst_comeback);
        };

        mov(kj, reg_kh);
        L(kh_label);
        {
            const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
            // icb loop is supported for nxc layout only
            assert(IMPLICATION(generate_icb_loop,
                    is_src_layout_nxc() && is_ddst_layout_nxc()));
            Label icb_block_label, icb_block_label_cb, ic_tail_loop_label;

            if (generate_icb_loop) {
                mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
                mov(ptr[rsp + icb_loop_src_ptr], reg_src);
            }
            if (ic_tail || generate_icb_loop)
                mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            L(icb_block_label);

            ic_loop(ic_block_step);

            if (generate_icb_loop) {
                // Advance to the next ic block and loop while work remains.
                add(reg_src, get_src_offset(ic_block, 0));
                safe_add(reg_kernel,
                        get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw),
                        reg_long_offt);
                sub(reg_icb, ic_block);
                cmp(reg_icb, 0);
                jg(icb_block_label, T_NEAR);
            }

            if (generate_icb_loop) {
                // restore pointers
                mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
                mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
            }

            // Step src to the next kh input row, kernel to the next kh slice.
            add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
            add(reg_kernel, get_kernel_offset(0, jcp.kw));
            dec(kj);
            cmp(kj, 0);
            jg(kh_label, T_NEAR);
        }
    }
    if (jcp.ndims == 5) {
        // Outer kd loop: advance the depth pointers and iterate.
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
}
3193 | |
// Selects and emits one of the three oh-step variants based on static shape
// parameters (small fully-unrolled, row-unrolled, or the general common
// path), saving/restoring the 5D depth-loop registers around it and
// restoring the src/kernel pointers before returning.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_disp() {
    int ic_block_step = jcp.ic_block_step;

    // Full kw*kh(*kd) unrolling is avoided when strides > 1 make the
    // unrolled body too large.
    bool too_large_to_unroll = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
            && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);

    int ow = jcp.tr_ow;
    if (jcp.ndims == 5) {
        /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
         * 'movs' must be guaranteed. */
        mov(ki, reg_kd_count);
        mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
        mov(aux_reg_src, reg_src);
        mov(aux_reg_kernel, reg_kernel);
    }
    if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
        compute_oh_step_unroll_ow_icblock(ic_block_step);
    } else if (ow <= max_ur_w) {
        compute_oh_step_unroll_ow(ic_block_step);
    } else {
        compute_oh_step_common(ic_block_step);
    }

    // In harness_3d_reduction case calculation of diff_bias is called
    // for every ow row separately to be aligned with od loop in
    // compute_od_loop_common()
    if (jcp.harness == harness_3d_reduction) compute_diff_bias_row();
    if (jcp.ndims == 5) {
        // Restore the registers clobbered by the depth loop (see NOTE above).
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
        mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
        od_step_comeback_pointers();
    } else {
        oh_step_comeback_pointers();
    }
}
3230 | |
// Zero-initializes the diff_weights block (and, on the first ic chunk, the
// bias) when the runtime `channel` flag requests it — i.e. before the first
// accumulation pass. Supports the optional icb loop for nxc layouts.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel() {
    // In full-spatial harness the weights are fully overwritten, so only the
    // bias (if any) may need zeroing.
    if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
    Label skip_zeroing, zeroing_loop;

    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    cmp(reg_tmp, 0);
    jz(skip_zeroing, T_NEAR);

    Zmm zero = Zmm(0);
    vpxord(zero, zero, zero);
    if (jcp.with_bias) {
        Label skip_bias_zeroing;
        mov(reg_tmp, ptr[param + GET_OFF(flags)]);
        // Bias is zeroed only by the first ic chunk.
        test(reg_tmp, FLAG_IC_FIRST);
        jz(skip_bias_zeroing, T_NEAR);

        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp], zero);

        L(skip_bias_zeroing);
        if (jcp.harness == harness_compute_full_spatial)
            jmp(skip_zeroing, T_NEAR);
    }

    const size_t kernel_block_bytes
            = get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
    Label icb_block_label, icb_block_label_cb;

    const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
    // icb loop is supported for nxc layout only
    assert(IMPLICATION(
            generate_icb_loop, is_src_layout_nxc() && is_ddst_layout_nxc()));
    if (generate_icb_loop) {
        mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
        mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
        L(icb_block_label);
    }

    // Zero the whole kernel block, jcp.ic_block vectors per iteration.
    xor_(reg_tmp, reg_tmp);
    L(zeroing_loop);
    {
        assert(get_kernel_offset(1, 0) == cpu_isa_traits<avx512_core>::vlen);
        for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
            vmovups(ptr[reg_kernel + reg_tmp + get_kernel_offset(ic1, 0)],
                    zero);
        add(reg_tmp, get_kernel_offset(jcp.ic_block, 0));
        cmp(reg_tmp, kernel_block_bytes);
        jnz(zeroing_loop);
    }

    if (generate_icb_loop) {
        add(reg_kernel, kernel_block_bytes);
        sub(reg_icb, jcp.ic_block);
        cmp(reg_icb, 0);
        jg(icb_block_label, T_NEAR);
        // restore pointer
        mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
    }

    L(skip_zeroing);
}
3292 | |
// Emits the loop over output rows (oh). The generated code handles three
// regions separately:
//   1) 'top' rows where the filter overlaps the top padding (t_pad),
//   2) the middle body where the filter fully overlaps the source,
//   3) 'bottom' rows overlapping the bottom padding (b_pad).
// In the padded regions the effective filter height (reg_kh) and the kernel/
// src pointers are adjusted per row so only in-bounds filter taps are used.
// When is_partial is true, the row range [os_index_begin, os_index_end) comes
// from the call args and the emitted prologue fast-forwards the register
// state to os_index_begin.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_oh_loop_common(
        bool is_partial) {
    int b_pad = jcp.b_pad;
    int t_pad = jcp.t_pad;
    bool is_dilated = jcp.dilate_h != 0;
    int dilate_h = jcp.dilate_h + 1;
    int stride_h = jcp.stride_h;
    auto filter_step_size = get_kernel_offset(0, jcp.kw);
    auto src_step_size = get_src_offset(0, 0, 1);
    auto ddst_step_size = get_ddst_offset(0, 1);
    Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
            oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
            oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
            oh_dilate_label_end, oh_dilate_setup_label_shift,
            oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;

    // Row-range boundaries of the three regions (computed at codegen time).
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
    int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
    int oh_head_overflow_end = div_up(t_pad, stride_h);
    int oh_tail_end = jcp.oh;

    // Input row where the body region starts, relative to stride alignment.
    int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
    int ih_body_end
            = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);

    if (is_partial)
        mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
    else
        xor_(reg_oj, reg_oj);

    /* Compute 'top' edge */
    if (t_pad > 0) {
        if (is_partial) {
            cmp(reg_oj, oh_head_overflow_end);
            jge(oh_tpad_tail_label_end, T_NEAR);
        }
        // Number of filter taps falling off the bottom/top of the source for
        // the very first output row.
        const int overflow
                = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
        const int underflow = div_up(t_pad, dilate_h);
        const int initial_kh = jcp.kh - overflow - underflow;

        // Setup reg_kh, reg_kernel, and reg_src
        mov(reg_kh, initial_kh);
        add(reg_kernel, filter_step_size * underflow);
        if (is_dilated) {
            const int tail = t_pad % dilate_h;
            const int shift = tail == 0 ? 0 : dilate_h - tail;
            mov(reg_ih_shift, shift);
            if (!is_partial) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
            add(reg_src, src_step_size * shift);
        }

        if (is_partial) {
            // Replay the per-row pointer/kh adjustments reg_oj times so the
            // register state matches what a full run would have at this row.
            Label head_setup, head_setup_finish;
            cmp(reg_oj, 0);
            je(head_setup_finish, T_NEAR);
            mov(reg_oj_setup, reg_oj);

            L(head_setup);
            if (is_dilated) {
                inc(reg_ih_shift);
                cmp(reg_ih_shift, dilate_h);
                jl(oh_dilate_setup_label_shift, T_NEAR);
                // unshift src as new kernel element enters
                sub(reg_src, src_step_size * (dilate_h - 1));
                xor_(reg_ih_shift, reg_ih_shift);
            }
            // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
            add(reg_kh, stride_h);
            sub(reg_kernel, filter_step_size * stride_h);
            if (is_dilated) {
                jmp(oh_dilate_setup_label_noshift, T_NEAR);
                L(oh_dilate_setup_label_shift);
                // shift src as old kernel element progresses
                add(reg_src, src_step_size * stride_h);
                L(oh_dilate_setup_label_noshift);
            }
            sub(reg_oj_setup, 1);
            jg(head_setup, T_NEAR);
            L(head_setup_finish);

            if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
            if (oh_head_end < oh_head_overflow_end) {
                cmp(reg_oj, oh_head_end);
                jge(oh_tpad_label_end, T_NEAR);
            }
        }

        //Setup reg_kernel
        // If dilated, shift src ptr
        // Loop
        L(oh_tpad_label);
        compute_oh_step_disp();
        add(reg_ddst, ddst_step_size);
        if (is_dilated) {
            mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
            inc(reg_ih_shift);
            mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
            cmp(reg_ih_shift, dilate_h);
            jl(oh_dilate_label_shift, T_NEAR);
            // unshift src as new kernel element enters
            sub(reg_src, src_step_size * (dilate_h - 1));
            xor_(reg_ih_shift, reg_ih_shift);
            mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
        }
        // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
        add(reg_kh, stride_h);
        sub(reg_kernel, filter_step_size * stride_h);
        if (is_dilated) {
            jmp(oh_dilate_label_noshift, T_NEAR);
            L(oh_dilate_label_shift);
            // shift src as old kernel element progresses
            add(reg_src, src_step_size * stride_h);
            L(oh_dilate_label_noshift);
        }
        inc(reg_oj);

        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }
        cmp(reg_oj, oh_head_end);
        jl(oh_tpad_label, T_NEAR);

        L(oh_tpad_label_end);
        // need second loop to process kernel if it is larger than the src
        // (does not apply to dilations as they must have unit stride)
        if (oh_head_end < oh_head_overflow_end) {
            assert(!is_dilated);

            cmp(reg_oj, oh_head_overflow_end);
            jge(oh_tpad_tail_label_end, T_NEAR);

            // Filter taller than the source: every row uses all ih rows.
            mov(reg_kh, jcp.ih);
            L(oh_tpad_tail_label);
            {
                compute_oh_step_disp();
                add(reg_ddst, ddst_step_size);
                sub(reg_kernel, filter_step_size * stride_h);

                inc(reg_oj);

                if (is_partial) {
                    cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                    jge(oh_bpad_label_end, T_NEAR);
                }
                cmp(reg_oj, oh_head_overflow_end);
                jl(oh_tpad_tail_label, T_NEAR);
            }
        }
        // Re-align kernel/src pointers to the first body row when t_pad is
        // not a multiple of stride_h.
        if (body_src_start_offset != 0) {
            add(reg_kernel, filter_step_size * body_src_start_offset);
            add(reg_src, src_step_size * body_src_start_offset);
        }
        L(oh_tpad_tail_label_end);
    }

    if (is_partial) {
        cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
        jge(oh_bpad_label_end, T_NEAR);
    }
    cmp(reg_oj, oh_body_end);
    jge(oh_label_end, T_NEAR);

    /* Compute middle block(s) */
    mov(reg_kh, jcp.kh);
    L(oh_label);
    {
        compute_oh_step_disp();
        add(reg_src, src_step_size * stride_h);
        add(reg_ddst, ddst_step_size);

        inc(reg_oj);

        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }

        cmp(reg_oj, oh_body_end);
        jl(oh_label, T_NEAR);
    }
    L(oh_label_end);

    /* Compute bottom edge */
    if (b_pad > 0) {
        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_label_end, T_NEAR);

        if (is_dilated) {
            // Assumes unit stride for dilations
            mov(reg_kh, jcp.kh - 1);
            xor_(reg_ih_shift, reg_ih_shift);
        } else {
            assert(jcp.dilate_h == 0);
            mov(reg_kh, jcp.ih - ih_body_end);
        }
        if (is_partial) {
            // Fast-forward reg_kh (and the dilation phase) past the bottom
            // rows a full run would already have processed.
            lea(reg_oj_setup,
                    ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
            if (stride_h == 1 && !is_dilated) {
                sub(reg_kh, reg_oj_setup);
            } else {
                Label body_setup, body_setup_finish, dilate_skip;
                cmp(reg_oj_setup, 0);
                je(body_setup_finish, T_NEAR);

                L(body_setup);
                if (is_dilated) {
                    inc(reg_ih_shift);
                    cmp(reg_ih_shift, dilate_h);
                    jl(dilate_skip, T_NEAR);
                    xor_(reg_ih_shift, reg_ih_shift);
                }
                sub(reg_kh, stride_h);
                L(dilate_skip);
                sub(reg_oj_setup, 1);
                jg(body_setup, T_NEAR);
                L(body_setup_finish);
            }
        }

        if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
        L(oh_bpad_label);
        {
            compute_oh_step_disp();
            add(reg_src, src_step_size * stride_h);
            add(reg_ddst, ddst_step_size);

            if (is_dilated) {
                // reg_kh shrinks only once per dilate_h rows.
                mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
                inc(reg_ih_shift);
                mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
                cmp(reg_ih_shift, dilate_h);
                jl(oh_dilate_label_end, T_NEAR);
                xor_(reg_ih_shift, reg_ih_shift);
                mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
            }
            sub(reg_kh, stride_h);
            L(oh_dilate_label_end);
            inc(reg_oj);
            if (is_partial) {
                cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                jge(oh_bpad_label_end, T_NEAR);
            }
            cmp(reg_oj, oh_tail_end);
            jl(oh_bpad_label, T_NEAR);
        }
    }
    L(oh_bpad_label_end);
}
3549 | |
// Emits the loop over the output depth dimension (od) for 3D convolutions
// (harness_3d_reduction only). Each iteration runs the full oh loop via
// compute_oh_loop_common(), then adjusts reg_kd_count and the kernel/src
// depth pointers to account for front (f_pad) and back (back_pad) padding.
// When is_partial is true, the depth range [os_index_begin, os_index_end)
// and the starting kernel offset / kd padding come from the call args.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_od_loop_common(
        bool is_partial) {
    assert(jcp.harness == harness_3d_reduction);

    // First output-depth index whose filter overlaps the back padding.
    const int src_backpad_overlap
            = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);

    const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
    const auto src_shift = get_src_offset(0, 0, jcp.ih);
    const auto ddst_shift = get_ddst_offset(0, jcp.oh);

    const int kd_front_pad = nstl::max(0, jcp.f_pad);
    const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);

    Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
            backpad_end_label, backpad_label;

    /* initially offset 'kd' by f_pad */
    mov(reg_src_d, ptr[param + GET_OFF(src)]);
    mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);

    if (is_partial) {
        add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
        mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
        mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
    } else {
        // Skip the filter taps that fall into the front padding.
        const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
        const int kd_offset = get_kernel_offset(
                0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
        add(reg_kernel, kd_offset);
        xor_(reg_d_index, reg_d_index);
        mov(reg_kd_count, kd_padding);
    }

    cmp(reg_kd_count, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kd
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jge(loop_end_label, T_NEAR); // no iterations along depth dimension

    L(d_loop_label);

    mov(reg_src, reg_src_d);
    mov(reg_ddst, reg_ddst_d);

    // Spill the depth-loop state: compute_oh_loop_common() clobbers these
    // registers.
    mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
    mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
    mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);

    compute_oh_loop_common();

    mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
    mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
    mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));

    /* Compute 'front' edge */
    if (jcp.f_pad > 0) {
        /* Check if within fpad region */
        cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
        jge(fpad_end_label, T_NEAR);

        /* Fpad steps */
        sub(reg_kernel, filter_shift * jcp.stride_d);
        add(reg_kd_count, jcp.stride_d);

        /* Final number of kernel elements that overlap with src */
        const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
        cmp(reg_kd_count, src_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and src */
        if (jcp.f_pad <= jcp.od * jcp.stride_d) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.f_pad % jcp.stride_d != 0) {
                int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
                add(reg_kernel, filter_shift * src_corr);
                add(reg_src_d, src_shift * src_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kd_count, src_ker_overlap);
        jmp(common_block_label);

        L(fpad_end_label);
    }

    /* Compute bottom edge */
    if (jcp.back_pad > 0) {

        /* Check if within back_pad region */
        cmp(reg_d_index, src_backpad_overlap - 1);
        jl(backpad_end_label, T_NEAR);
        jg(backpad_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * back_pad region. */
        mov(reg_kd_count,
                jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
        jmp(backpad_end_label, T_NEAR);

        L(backpad_label);
        sub(reg_kd_count, jcp.stride_d);
        cmp(reg_kd_count, 0);
        jle(loop_end_label, T_NEAR);

        L(backpad_end_label);
    }

    /* Compute middle block */
    add(reg_src_d, src_shift * jcp.stride_d);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_ddst_d, ddst_shift);
    inc(reg_d_index);
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jl(d_loop_label, T_NEAR);

    L(loop_end_label);
}
3679 | |
// Emits the "full spatial" harness: the whole spatial domain is reduced by
// this kernel in one call, blocked over OH to limit the L2 working set.
// NOTE: the low bit of reg_ker is borrowed as a runtime "needs zeroing"
// flag (set from the 'channel' call arg below), so the kernel pointer must
// be at least 2-byte aligned.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
        compute_full_spat_loop() {
    // General code layout:
    //
    // Blocking over OH -- top level
    // (Reduces L2 pressure; not very useful right now)
    //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
    //    Loop over OH block -- emit_h_loop()
    //      Loop over OW blocks -- emit_fma_block()
    //      (Supports both fully unrolled and partially unrolled
    //      versions to reduce code size)
    //          Loop over OW block -- emit_fma_step()

    auto src_row_size = get_src_offset(0, 0, 1);
    auto ddst_row_size = get_ddst_offset(0, 1);
    auto row_size = src_row_size + ddst_row_size;

    // Shrink h_block_size (by its smallest divisor, or halving) until the
    // working set fits the optimal size, but never below min_h_block_size.
    int h_block_size = jcp.oh;
    int h_last_block_size = h_block_size;
    int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
    auto working_set_size = row_size * h_block_size;

    if (working_set_size > full_spat_max_working_set_size) {
        assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);

        while (working_set_size > full_spat_opt_working_set_size
                && h_block_size >= min_h_block_size) {
            for (int i = 2; i <= h_block_size; i++)
                if (i == h_block_size)
                    h_block_size = h_block_size / 2;
                else if (h_block_size % i == 0) {
                    h_block_size = h_block_size / i;
                    break;
                }
            working_set_size = row_size * h_block_size;
        }
        h_block_size = nstl::max(min_h_block_size, h_block_size);
        h_last_block_size = jcp.oh % h_block_size;
        if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
    }

    // Local register/mask assignment for this harness.
    Opmask reg_h_block = k1;
    Reg64 reg_kh = rax;
    Reg64 reg_kw = rbx;
    Reg64 reg_tmp = abi_not_param1;
    Reg32 reg_tmp_w = reg_tmp.cvt32();
    Reg64 reg_ohs = rdx;
    Reg64 reg_ihs = rsi;
    Reg64 reg_h = r8;
    Reg64 reg_i = r9;
    Reg64 reg_j = r10;

    Reg64 reg_src = r13;
    Reg64 reg_ddst = r14;
    Reg64 reg_ker = r15;

    Reg64 reg_src_save = abi_param1;
    Reg64 reg_ddst_save = reg_tmp;

    // Address helpers: ddst accumulators live in zmm24..zmm31; the kernel
    // accumulators occupy zmm0..zmm(ic_block-1).
    auto zmm_ddst = [&](int oi) { return Zmm(24 + oi % 8); };
    auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
    auto src_addr = [&](int oi, int ic1) {
        return zword_b[reg_src + get_src_offset(ic1, oi)];
    };
    auto ddst_addr = [&](int oi) {
        auto ow_per_oc = 2;
        return ptr[reg_ddst + get_ddst_offset(ow_per_oc * oi)];
    };
    auto ker_addr
            = [&](int ic1) { return ptr[reg_ker + get_kernel_offset(ic1, 0)]; };

    // One OW block: bf16 dot-products of ddst rows with broadcast src,
    // accumulated into the per-ic kernel registers.
    auto emit_block = [&]() {
        auto pad_ow = jcp.tr_ow;
        int ow_per_oc = 2;
        int def_step_size = 16;
        bool has_w_tail = pad_ow % def_step_size != 0;
        bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;

        auto emit_step = [&](int ur_ow, bool is_w_tail) {
            int tail_size = pad_ow % ur_ow;
            int this_ur_ow = (is_w_tail && tail_size) ? tail_size : ur_ow;
            auto numloads = 1;

            assert(this_ur_ow % ow_per_oc == 0);
            int steps = this_ur_ow / ow_per_oc;
            for (int oi_base = 0; oi_base < steps; oi_base += numloads) {
                for (int oi_offset = 0; oi_offset < numloads; oi_offset++) {
                    int oi = oi_base + oi_offset;
                    if (oi < steps) {
                        vmovups(zmm_ddst(oi), ddst_addr(oi));
                    } else {
                        auto zmm = zmm_ddst(oi);
                        vpxord(zmm, zmm, zmm);
                    }
                }

                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    vdpbf16ps(zmm_ker(ic1), zmm_ddst(oi_base),
                            src_addr(ow_per_oc * oi_base, ic1));
                }
            }
        };

        if (full_w_unroll) {
            emit_step(pad_ow, true);
        } else {
            Label w_loop;
            int num_w_iters = pad_ow / def_step_size;
            mov(reg_i, num_w_iters);
            L(w_loop);
            {
                emit_step(def_step_size, false);
                add(reg_src, get_src_offset(0, def_step_size));
                add(reg_ddst, get_ddst_offset(def_step_size));
                sub(reg_i, 1);
                jnz(w_loop);
            }
            if (has_w_tail) { emit_step(def_step_size, true); }
            // reset reg_src and reg_ddst because emit_h_loop expects
            // unmodified pointers
            int w_offset = num_w_iters * def_step_size;
            sub(reg_src, get_src_offset(0, w_offset));
            sub(reg_ddst, get_ddst_offset(w_offset));
        }
    };

    // Runs emit_block() reg_h times, advancing src/ddst one row per
    // iteration; the last row is emitted after the loop (fall-through).
    auto emit_h_loop = [&]() {
        Label h_loop, skip_h_loop;
        mov(reg_j, 1);
        cmp(reg_j, reg_h);
        je(skip_h_loop, T_NEAR);
        L(h_loop);
        {
            emit_block();

            add(reg_src, get_src_offset(0, 0, 1));
            add(reg_ddst, get_ddst_offset(0, 1));
            add(reg_j, 1);
            cmp(reg_j, reg_h);
            jb(h_loop);
        }
        L(skip_h_loop);

        emit_block();
    };

    // Loop over kh x kw filter points for one OH block. For each (kh, kw)
    // the ddst/src rows that actually overlap are computed in registers
    // (reg_ohs/reg_ihs); empty overlaps zero or skip the kernel point.
    auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
        xor_(reg_kh, reg_kh);
        Label kh_loop, kh_loop_end;

        int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
        // NB: this is correct because we only support t_pad = kh / 2 and thus
        // ih == oh
        int ih_block_size = oh_block_size
                + (!is_first_block + !is_last_block) * jcp.t_pad;

        L(kh_loop);
        {
            if (is_first_block) {
                // reg_ohs = max(t_pad - kh, 0); reg_ihs = ohs - t_pad + kh.
                xor_(reg_tmp, reg_tmp);
                mov(reg_ohs, jcp.t_pad);
                sub(reg_ohs, reg_kh);
                cmovb(reg_ohs, reg_tmp);

                mov(reg_ihs, reg_ohs);
                sub(reg_ihs, jcp.t_pad);
                add(reg_ihs, reg_kh);
            } else {
                xor_(reg_ohs, reg_ohs);
                mov(reg_ihs, reg_kh);
            }

            // reg_h = number of overlapping rows for this kh.
            mov(reg_tmp, oh_block_size);
            sub(reg_tmp, reg_ohs);
            mov(reg_h, ih_block_size);
            sub(reg_h, reg_ihs);
            cmp(reg_tmp, reg_h);
            cmovb(reg_h, reg_tmp);

            Label kh_loop_work;
            cmp(reg_h, 0);
            jg(kh_loop_work, T_NEAR);

            // empty h loop for this jcp.kh:
            // - set the ddst to 0 if necessary
            // - move ker pt
            // - jump to the end
            sub(reg_h, 1);
            Label skip_ker_zeroing;

            // The reg_ker ptr has highest bit set if the ddst needs to be
            // zeroed. Those who have byte-aligned their data will suffer the
            // consequences :(
            // TODO: move the flag to a mask register? (Roma)
            test(reg_ker, 1);
            jz(skip_ker_zeroing, T_NEAR);

            Label zeroing_loop;
            vpxord(zmm0, zmm0, zmm0);
            and_(reg_ker, ~1); // temporarily clear the zeroing flag
            mov(reg_tmp, jcp.kw);
            L(zeroing_loop);
            {
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
                    vmovups(ker_addr(ic1), zmm0);
                add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
                sub(reg_tmp, 1);
                jnz(zeroing_loop, T_NEAR);
            }
            // restore the zeroing flag (it will be cleared after the end of
            // emit_kh_kw_loop, but we may need it until then)
            or_(reg_ker, 1);
            jmp(kh_loop_end, T_NEAR);

            L(skip_ker_zeroing);
            add(reg_ker, get_kernel_offset(0, jcp.kw));
            jmp(kh_loop_end, T_NEAR);

            L(kh_loop_work);

            // Convert row indices to byte offsets and position the pointers.
            mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
            mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));

            add(reg_src, reg_ihs);
            add(reg_ddst, reg_ohs);

            Label kw_loop;
            xor_(reg_kw, reg_kw);
            L(kw_loop);
            {
                // Fresh accumulators for this (kh, kw) kernel point.
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    vpxord(zmm, zmm, zmm);
                }

                mov(reg_ddst_save, reg_ddst);
                mov(reg_src_save, reg_src);
                lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);

                emit_h_loop();

                mov(reg_ddst, reg_ddst_save);
                mov(reg_src, reg_src_save);

                Label do_store;
                // The reg_ker ptr has highest bit set if the ddst needs to
                // be zeroed. Those who have byte-aligned their data will
                // suffer the consiquences :(
                mov(reg_tmp, reg_ker);
                and_(reg_ker, ~1);
                test(reg_tmp, 1);
                jnz(do_store, T_NEAR);

                // No zeroing requested: accumulate into existing weights.
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    vaddps(zmm, ker_addr(ic1));
                }

                L(do_store);
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    vmovups(ker_addr(ic1), zmm);
                }

                mov(reg_ker, reg_tmp);
                add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
                add(reg_kw, 1);
                cmp(reg_kw, jcp.kw);
                jl(kw_loop);
            }

            sub(reg_src, reg_ihs);
            sub(reg_ddst, reg_ohs);

            L(kh_loop_end);
            add(reg_kh, 1);
            cmp(reg_kh, jcp.kh);
            jl(kh_loop);
        }
    };

    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);
    // Fold the 'channel' zeroing request into bit 0 of the kernel pointer.
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    or_(reg_ker, reg_tmp);

    bool single_kh_kw_loop = (h_last_block_size == jcp.oh);

    auto src_row_step = get_src_offset(0, 0, 1);
    auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
    auto ddst_block_step = get_ddst_offset(0, h_block_size);

    emit_kh_kw_loop(true, single_kh_kw_loop);

    if (!single_kh_kw_loop) {
        auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
        sub(reg_ker, ker_reset_offset);
        and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates

        add(reg_src, first_src_block_step);
        add(reg_ddst, ddst_block_step);

        // Middle OH blocks; the loop counter is kept in an opmask register
        // (k1) since all GPRs are in use.
        int num_innermost_iters
                = (jcp.oh - h_last_block_size) / h_block_size - 1;
        if (num_innermost_iters > 0) {
            Label h_block_loop;

            mov(reg_tmp_w, num_innermost_iters);
            kmovw(reg_h_block, reg_tmp_w);
            L(h_block_loop);
            {
                emit_kh_kw_loop(false, false);
                sub(reg_ker, ker_reset_offset);
                add(reg_src, src_row_step * h_block_size);
                add(reg_ddst, ddst_block_step);

                kmovw(reg_tmp_w, reg_h_block);
                sub(reg_tmp_w, 1);
                kmovw(reg_h_block, reg_tmp_w);
                jnz(h_block_loop);
            }
        }

        emit_kh_kw_loop(false, true);
    }
}
4007 | |
// Emits the kernel body: loads the opmask constants required by the chosen
// transposition scheme, initializes the src/ddst/kernel pointers from the
// call args, optionally zeroes the accumulators, and dispatches to the
// harness-specific main loop.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_loop() {
    Reg64 reg_mask_load = r11;
    if (jcp.uses_permw_transposition) {

        // Full / low-half / high-half 32-lane masks used by the permw path.
        mov(reg_mask_load.cvt32(), 0xffffffff);
        kmovd(m_ffffffff, reg_mask_load.cvt32());

        mov(reg_mask_load.cvt32(), 0x0000ffff);
        kmovd(m_0000ffff, reg_mask_load.cvt32());

        mov(reg_mask_load.cvt32(), 0xffff0000);
        kmovd(m_ffff0000, reg_mask_load.cvt32());
        // Tail masks for a partial OC block, in both half positions.
        const int oc_tail = jcp.oc_tail;
        if (oc_tail) {
            mov(reg_mask_load.cvt32(), (1 << oc_tail) - 1);
            kmovd(m_0000_oc_tail, reg_mask_load.cvt32());
            kshiftld(m_oc_tail_0000, m_0000_oc_tail, 16);
        }
        // Same for a partial IC block.
        const int ic_tail = jcp.ic_tail;
        if (ic_tail) {
            mov(reg_mask_load.cvt32(), (1 << ic_tail) - 1);
            kmovd(m_0000_ic_tail, reg_mask_load.cvt32());
            kshiftld(m_ic_tail_0000, m_0000_ic_tail, 16);
        }
    } else if (jcp.is_1stconv && !jcp.transpose_src) {
        if (jcp.stride_w == 1) {
            // 0x55555555 selects every other 16-bit lane (unit stride case).
            int ieveryother_mask = 0x55555555;
            mov(reg_mask_load.cvt32(), ieveryother_mask);
            kmovd(everyother_mask, reg_mask_load.cvt32());
            kshiftld(everyother_shift_mask, everyother_mask, 1);
        } else {
            mov(reg_mask_load.cvt32(), 0xffffffff);
            kmovd(m_ffffffff, reg_mask_load.cvt32());
        }
    }

    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_kernel, ptr[param + GET_OFF(filt)]);

    maybe_zero_kernel();
    maybe_compute_diff_bias();

    // Each harness owns its own loop nest over the spatial dims.
    switch (jcp.harness) {
        case harness_3d_reduction: compute_od_loop_common(true); break;
        case harness_2d_reduction: compute_oh_loop_common(true); break;
        case harness_mb_reduction: compute_oh_loop_common(); break;
        case harness_compute_full_spatial: compute_full_spat_loop(); break;
        default: assert(!"Invalid harness type" );
    }
}
4059 | |
4060 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::setup_stack_space() { |
4061 | |
4062 | if ((jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1) |
4063 | || jcp.uses_permw_transposition) { |
4064 | int ur_w, ur_w_tail, ur_w_trips; |
4065 | get_ur_w(ur_w, ur_w_tail, ur_w_trips); |
4066 | ur_w = nstl::max(ur_w, ur_w_tail); |
4067 | ic_block_step_stack_size = jcp.uses_permw_transposition |
4068 | ? permw_stack_size(ur_w) |
4069 | : interleave_stack_size(ur_w, jcp.ic_block_step); |
4070 | } else |
4071 | ic_block_step_stack_size = extern_ic_block_step_stack_size; |
4072 | |
4073 | permw_buffer_start = 0; |
4074 | kd_count_offset = ic_block_step_stack_size; |
4075 | src_d_offset = ic_block_step_stack_size + 8; |
4076 | ddst_d_offset = ic_block_step_stack_size + 16; |
4077 | d_index_offset = ic_block_step_stack_size + 24; |
4078 | trans_tmp_offset = ic_block_step_stack_size + 32; |
4079 | ih_dilate_shift = ic_block_step_stack_size + 40; |
4080 | icb_loop_ker_ptr = ic_block_step_stack_size + 48; |
4081 | icb_loop_src_ptr = ic_block_step_stack_size + 56; |
4082 | stack_space_needed = ic_block_step_stack_size + 64; |
4083 | } |
4084 | |
// Top-level code generation entry: emits the standard prologue, reserves the
// stack scratch area computed by setup_stack_space(), generates the main
// compute loop, restores the stack and emits the epilogue. For the permw
// transposition path, a 64-byte-aligned permutation table (interleaving the
// two 16-lane halves: 0,16,1,17,...) is appended after the code.
void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate() {
    preamble();

    setup_stack_space();

    sub(rsp, stack_space_needed);

    compute_loop();

    add(rsp, stack_space_needed);

    postamble();

    if (jcp.uses_permw_transposition) {
        align(64);
        L(dst_prm_table);
        // vpermw indices that zip lanes 0..15 with lanes 16..31.
        const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
                5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
                29, 14, 30, 15, 31};

        for (size_t i = 0; i < 32; ++i)
            dw(dst_prm_array[i]);
    }
}
4109 | |
// Fills 'jcp' with the full configuration of the bf16 backward-by-weights
// convolution kernel: problem geometry, memory formats (resolving
// format_kind::any), channel blocking, src/diff_dst transposition strategy,
// harness kind and the thread decomposition. Any shape/format combination
// this implementation does not support yields status::unimplemented.
status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    // 16 f32 lanes per zmm register
    const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);

    // Grouped weights carry one extra leading dimension (g)
    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;
    // Prefer native bf16 instructions; otherwise fall back to emulation
    jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                        : bf16_emulation_t::get_isa();
    // Unconditionally set for the bf16 path (native or emulated)
    jcp.has_vnni = true;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    // Channels are per-group
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    // Spatial sizes; depth (1D/2D) and height (1D) collapse to 1
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    // Front/top/left padding from the descriptor; end padding derived below
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    // Effective filter footprint once dilation is applied
    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    bool ok = true
            // general condition to simplify dilations
            && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
            && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
            && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
            // special condition to simplify dilations in compute_oh_loop_common
            && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
    if (!ok) return status::unimplemented;

    // End padding (right/bottom/back) implied by in/out sizes and strides
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    /* XXX: no support for padding when dilation_d > 0 */
    if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
        return status::unimplemented;

    // Padded input spatial extents; output spatial is used as-is
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;
    jcp.aligned_threads = 0;

    jcp.simd_w = simd_w;
    jcp.oc_block = simd_w;
    // Candidate data layouts: channels-last, plain, and 16c-blocked
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto curr_src_tag = src_d.matches_one_of_tag(
            dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx);
    auto curr_dst_tag
            = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    // Channels-last is in effect when at least one tensor is nxc and the
    // other either matches or is still free (format_kind::any)
    bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
                                      src_d.format_kind() == format_kind::any)
            && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                    diff_dst_d.format_kind() == format_kind::any)
            && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);

    jcp.is_1stconv = is_1stconv(jcp);

    // Channel padding to a multiple of simd_w is only valid for blocked
    // layouts without groups and outside the 1st-convolution special case
    bool ok_to_pad_channels
            = (jcp.ngroups == 1) && !jcp.is_1stconv && !is_data_layout_nxc;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // Resolve the tags this kernel will actually use
    auto src_tag = is_data_layout_nxc
            ? dat_tag_nxc
            : (jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c);
    auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
    auto wei_tag = jcp.is_1stconv
            ? pick(2 * ndims - 6 + with_groups, Owi16o, gOwi16o, Ohwi16o,
                    gOhwi16o, Odhwi16o, gOdhwi16o)
            : pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
                    OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);

    // Either initialize a still-free descriptor or require an exact match
    if (src_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, src_tag));
    } else if (curr_src_tag != src_tag)
        return status::unimplemented;
    jcp.src_tag = src_tag;

    if (diff_dst_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
    } else if (curr_dst_tag != dst_tag)
        return status::unimplemented;
    jcp.dst_tag = dst_tag;

    if (diff_weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, x));
    }
    jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad_h = ext_kh / 2;
    const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
            && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
            && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd
            && IMPLICATION(jcp.is_1stconv && jcp.ow > max_ur_w,
                    jcp.l_pad < max_ur_w && ext_kw <= jcp.ow);
    if (!boundaries_ok) return status::unimplemented;

    // Filter-width limit; the 1st-conv micro-kernel tolerates wider filters
    const int max_kw = jcp.is_1stconv ? 24 : 14;
    /* yet another common check */
    if (jcp.kw > max_kw) return status::unimplemented;

    jcp.wei_dt = diff_weights_d.data_type();

    jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
    if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    // Data-type support: bf16 src/diff_dst, f32 or bf16 diff_weights
    ok = true && one_of(ndims, 3, 4, 5)
            && everyone_is(
                    data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
            && one_of(diff_weights_d.data_type(), data_type::f32,
                    data_type::bf16);
    if (!ok) return status::unimplemented;

    // Channel tails only exist for the channels-last layout
    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.ic_block : 0;
    jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.oc_block : 0;

    // ic_block_step: how many input channels the micro-kernel accumulates
    // per step; heuristic, scaled down as the filter width grows
    if (jcp.is_1stconv) {
        jcp.ic_block_step = 24 / jcp.kw;
        while (jcp.ic_block % jcp.ic_block_step != 0)
            jcp.ic_block_step--;
    } else {
        jcp.ic_block_step
                = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : (jcp.kw <= 12 ? 2 : 1));
    }

    // jcp.uses_permw_transposition = false shows better performance for
    // resnet50 v1.5 problems
    // jcp.uses_permw_transposition = true works better for 3d 1x1x1 problems
    const bool is_permw_applicable
            = !jcp.is_1stconv && jcp.stride_w == 1 && jcp.dilate_w == 0;
    const bool apply_permw_blocked = !is_data_layout_nxc && ndims == 5
            && jcp.kw == 1 && jcp.ic_block_step > 4;
    // Threshold is based on performance measurements
    const bool apply_permw_nxc = is_data_layout_nxc && ndims == 3
            && nstl::max(jcp.ic, jcp.oc) <= 32;
    jcp.uses_permw_transposition
            = is_permw_applicable && (apply_permw_blocked || apply_permw_nxc);

    jcp.kernel_kind = embd_bcast;
    if (jcp.uses_permw_transposition && jcp.kw <= 3)
        jcp.kernel_kind = expl_bcast;
    if (jcp.uses_permw_transposition && jcp.kernel_kind == expl_bcast)
        jcp.ic_block_step = jcp.kw <= 3 ? 4 : (jcp.kw < 7 ? 2 : 1);

    // Decide which tensors go through an explicit transposition pass
    if (jcp.uses_permw_transposition) {
        jcp.transpose_src = false;
        jcp.transpose_dst = false;
    } else if (jcp.is_1stconv && IMPLICATION(is_data_layout_nxc, jcp.ic == 1)) {
        jcp.transpose_src = false;
        jcp.transpose_dst = true;
    } else {
        jcp.transpose_src = true;
        jcp.transpose_dst = true;
    }

    const bool is_2d = (ndims == 4);
    const bool is_3d = (ndims == 5);
    jcp.typesize_in = sizeof(bfloat16_t);
    jcp.typesize_out = sizeof(float);
    // L2 size expressed in f32 elements, used by the heuristics below
    const dim_t cache_l2
            = platform::get_per_core_cache_size(2) / jcp.typesize_out;

    // Observation: Given large 3D shapes with large filter size, 1st nspc
    // bwd_w convolution benefits from non-temporal stores in diff_dst
    // transformation but not so much from blocking w.r.t. depth dimension
    // In particular, it's optimized for i3D 1st convolution
    const bool nt_stores_ok = is_data_layout_nxc
            && dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= 2 * cache_l2
            && jcp.kd >= 6 && jcp.kh >= 6 && jcp.kw >= 6;

    // Performancewise transposition of diff_dst tensor is one of the major
    // bottleneck in 1st convolution. Thus for large diff_dst size we can
    // potentially further split up transposition in smaller chunks to achieve
    // better cache reuse
    const bool large_diff_dst_size
            = dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= cache_l2;

    // For two dimensional diff_dst tensor blocking along height demands
    // non-trivial work along width dimension. Similarly, for three dimensional
    // diff_dst tensor enough work must be present in the joint width-height
    // dimension. Finally, there is no blocking along the width dimension
    const bool blocking_ok = large_diff_dst_size
            && IMPLICATION(is_2d, jcp.ow >= 124 && jcp.oh > 1)
            && IMPLICATION(is_3d, jcp.ow * jcp.oh >= 64 * 124 && jcp.od > 1)
            && (is_2d || is_3d);

    // TODO: Find more shapes (especially 3D with large spatials) for which
    // local transposition will be beneficial. Furthermore, for TBB threads
    // more shapes can potentially benefit from spatial blocking
    bool use_spatial_blocking = jcp.is_1stconv && !nt_stores_ok && blocking_ok;
    int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
    if (use_spatial_blocking) {
        // Default value, works best most of the times
        // TODO: For 3D shapes with intermediate sizes especially the ones not
        // belonging to the 1st convolution, we potentially have more scope
        // for optimization
        optimal_blk_size = 1;

        // Diff_weights computation can be roughly broken down into
        // the following three steps
        //    = [Src transform*] + [Diff_dst transform] + [Weights computation]
        //
        // where the bottleneck lies with diff_dst transform that spatial
        // blocking tries to mitigate by avoiding cache thrashing.
        // *note: Src transform may not always be needed.
        //
        // In an idealistic scenario, optimal_blk_size will be an explicit
        // function of the following form
        //     optimal_blk_size = f(od, oh, ow, oc)
        //
        // though owing to lack of data points w.r.t. 1st convolution shapes it
        // is approximated by one with few exceptional cases [found by manual
        // optimization] as written below

        // Manually tuned output-height exceptions for 2D 1st convolutions
        if (is_2d && utils::one_of(jcp.oh, 149, 300, 224, 512, 608)) {
            switch (jcp.oh) {
                case 149: optimal_blk_size = 10; break;
                case 224: optimal_blk_size = 56; break;
                case 300: optimal_blk_size = 30; break;
                case 512: optimal_blk_size = 8; break;
                case 608: optimal_blk_size = 10; break;
            }
        }
    }

    // Global (shared, barrier-synchronized) transposition requires syncable
    // threading and is mutually exclusive with spatial blocking
    jcp.global_transpose = dnnl_thr_syncable() && !use_spatial_blocking;
    jcp.use_nt_stores_ddst = jcp.global_transpose && nt_stores_ok;
    jcp.spatial_blk_size = optimal_blk_size;

    // Without src transposition padding is handled inline by the kernel,
    // which limits how much padding it can absorb
    const bool padding_ok = IMPLICATION(!jcp.transpose_src,
            jcp.l_pad < max_ur_w && jcp.r_pad < max_ur_w
                    && ext_kw <= jcp.iw + 1);
    if (!padding_ok) return status::unimplemented;

    const int tr_round = 2;
    // Logic for tr_pad calculation: transpose is used in the extern kernel.
    // There is a memory usage optimization where physical padding is shared
    // between transpose buffers. In calculating on a row, data is read from the
    // src 2 elements at a time due to the bf16 broadcast. Calculation starts
    // at the beginning of the left padding and ends at the end of the right
    // padding. Because elements are read two at a time, we may need r_pad + 1
    // padding on the right. As such, the shared padding is the max of l_pad and
    // r_pad + 1, rounded as necessary for the transpose data format.
    int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
    jcp.tr_iw = jcp.transpose_src
            ? rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
                    * jcp.stride_w
            : jcp.iw;

    jcp.tr_src_num_guard_elems = tr_pad; // upper bound
    jcp.tr_ow = jcp.transpose_dst ? rnd_up(jcp.ow, 2) : jcp.ow;

    // Final consistency checks between jcp and the (possibly padded)
    // memory descriptors
    bool args_ok = true
            && IMPLICATION(!is_data_layout_nxc,
                    jcp.ic % jcp.ic_block == 0 && jcp.oc % jcp.oc_block == 0)
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    // "Full spatial" harness: only profitable when the working set of one
    // height block (input + output rows) fits the tuned size threshold
    int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
    int out_row_size = jcp.oc_block * jcp.tr_ow * jcp.typesize_in;
    int full_spat_min_h_block_size
            = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
    int full_spat_working_set_size
            = (inp_row_size + out_row_size) * full_spat_min_h_block_size;
    bool use_full_spat_loop = isa_has_bf16(jcp.isa) && jcp.ndims < 5
            && jcp.ih == jcp.oh && jcp.iw == jcp.ow
            && !one_of(1, jcp.kh, jcp.kw)
            && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
            && !jcp.uses_permw_transposition && !jcp.is_1stconv
            && full_spat_working_set_size <= full_spat_opt_working_set_size
            && jcp.ic >= 128;

    jcp.harness = ndims == 5
            ? harness_3d_reduction
            : (use_full_spat_loop ? harness_compute_full_spatial
                            : (ndims == 4) ? harness_2d_reduction
                                           : harness_mb_reduction);

    // Amount of parallel work available along the "mb" reduction axis
    switch (jcp.harness) {
        case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
        case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
        case harness_compute_full_spatial:
        case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
        default: assert(!"Invalid harness" ); jcp.nthr_mb_work = jcp.mb;
    }
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;

        // TODO: Optimize memory allocation when threaded on height and depth
        if (jcp.transpose_src) {
            jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
            jcp.tr_src_buf_count = jcp.global_transpose
                    ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
                    : jcp.nthr;
        }
        if (jcp.transpose_dst) {
            jcp.tr_diff_dst_buf_size
                    = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
            jcp.tr_diff_dst_buf_count = jcp.global_transpose
                    ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
                    : jcp.nthr;
        }
    }

    jcp.nb_ic_blocking_max = 1;
    if (is_data_layout_nxc && jcp.uses_permw_transposition
            && (jcp.ow > max_ur_w || jcp.ndims == 5))
        jcp.nb_ic_blocking_max = nstl::min(8, div_up(jcp.nb_ic, jcp.nthr_ic_b));
    return status::success;
}
4496 | |
4497 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad( |
4498 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
4499 | |
4500 | if (!jcp.uses_permw_transposition) { |
4501 | // XXX: See the comment about tr_iw and guarding elements in |
4502 | // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf() |
4503 | const size_t tr_src_size = jcp.tr_src_buf_count * jcp.tr_src_buf_size |
4504 | + jcp.tr_src_num_guard_elems; |
4505 | scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in); |
4506 | |
4507 | /* prepare synchronization contexts */ |
4508 | if (jcp.global_transpose && jcp.nthr_oc_b > 1) { |
4509 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
4510 | scratchpad.book<simple_barrier::ctx_t>( |
4511 | key_conv_tr_src_bctx, tr_src_bctx_size); |
4512 | } |
4513 | |
4514 | const size_t tr_diff_dst_size |
4515 | = jcp.tr_diff_dst_buf_count * jcp.tr_diff_dst_buf_size; |
4516 | |
4517 | const size_t min_align = jcp.use_nt_stores_ddst ? 64 : jcp.typesize_in; |
4518 | scratchpad.book(key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, |
4519 | min_align); |
4520 | |
4521 | /* prepare synchronization contexts */ |
4522 | if (jcp.global_transpose && jcp.nthr_ic_b > 1) { |
4523 | const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b; |
4524 | scratchpad.book<simple_barrier::ctx_t>( |
4525 | key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size); |
4526 | } |
4527 | } |
4528 | |
4529 | if (IMPLICATION(jcp.nthr_mb == 1, |
4530 | (jcp.with_bias && jcp.bia_dt == data_type::bf16) |
4531 | || jcp.wei_dt == data_type::bf16)) { |
4532 | const size_t wei_size = static_cast<size_t>(jcp.ngroups) * jcp.nb_oc |
4533 | * jcp.oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw |
4534 | * jcp.kd; |
4535 | const size_t bia_size |
4536 | = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
4537 | |
4538 | const int num_wei_buffers |
4539 | = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1; |
4540 | const int num_bia_buffers = jcp.with_bias |
4541 | ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb |
4542 | : jcp.nthr_mb - 1) |
4543 | : 0; |
4544 | |
4545 | const size_t wei_bia_reduction_size |
4546 | = wei_size * num_wei_buffers + bia_size * num_bia_buffers; |
4547 | |
4548 | scratchpad.book<float>( |
4549 | key_conv_wei_bia_reduction, wei_bia_reduction_size); |
4550 | |
4551 | if (jcp.global_transpose) |
4552 | scratchpad.book<simple_barrier::ctx_t>( |
4553 | key_conv_wei_bia_reduction_bctx, 1); |
4554 | } |
4555 | |
4556 | if (jcp.with_bias) { |
4557 | if ((jcp.oc_without_padding % jcp.oc_block != 0) |
4558 | && jcp.bia_dt == data_type::f32) |
4559 | scratchpad.book(key_conv_padded_bias, |
4560 | jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia); |
4561 | } |
4562 | } |
4563 | |
4564 | void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance( |
4565 | const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, |
4566 | int &nthr_oc_b_, int &nthr_ic_b_) { |
4567 | nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; |
4568 | |
4569 | const int max_threads = dnnl_get_max_threads(); |
4570 | |
4571 | if (max_threads < j.ngroups) { |
4572 | /* simplification... fortunately it doesn't hurt much */ |
4573 | nthr_ = nthr_g_ = max_threads; |
4574 | return; |
4575 | } |
4576 | |
4577 | nthr_g_ = j.ngroups; |
4578 | const int nthr = max_threads / nthr_g_; |
4579 | |
4580 | auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { |
4581 | /* calculate per thread memory cost (read/write). high level optimizer |
4582 | * tries to minimize memory consumption. few notes: |
4583 | * (n1) if weights tensor size is less than source and destination |
4584 | * tensors we apply the ratio of the source and destination |
4585 | * tensor sizes to weights one as compensation coefficient to |
4586 | * avoid parallelization across batch size only, othervise we |
4587 | * apply additional coefficient to source component based on |
4588 | * performance measurements |
4589 | * (n2) use scales based on output vs input channels ratio for source |
4590 | * and destination componets to imporve threading balance across |
4591 | * input and output channels */ |
4592 | |
4593 | const dim_t src_type_size = 2; |
4594 | const dim_t wei_type_size = 4; |
4595 | |
4596 | dim_t src_size |
4597 | = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size; |
4598 | dim_t dst_size |
4599 | = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size; |
4600 | dim_t wei_size |
4601 | = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size; |
4602 | |
4603 | float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size; |
4604 | float oi_channels_ratio = (float)j.nb_oc / j.nb_ic; |
4605 | auto get_src_coef = [=]() { |
4606 | float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f); |
4607 | if (wei_compensation_scale < 1.0f) src_coef *= 4.0f; |
4608 | |
4609 | return src_coef; |
4610 | }; |
4611 | |
4612 | auto get_dst_coef |
4613 | = [=]() { return nstl::max(oi_channels_ratio, 1.0f); }; |
4614 | |
4615 | auto get_wei_coef |
4616 | = [=]() { return nstl::max(wei_compensation_scale, 1.0f); }; |
4617 | |
4618 | const float src_coef = get_src_coef(); |
4619 | const float dst_coef = get_dst_coef(); |
4620 | const float wei_coef = get_wei_coef(); |
4621 | |
4622 | float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb) |
4623 | * div_up(j.ngroups, nthr_g_) * div_up(j.nb_ic, nthr_ic_b) * j.mb |
4624 | * j.ic_block * j.id * j.ih * j.tr_iw / j.nthr_mb_work |
4625 | / j.stride_d / j.stride_h / j.stride_w; |
4626 | float wei_v = wei_coef * div_up(j.ngroups, nthr_g_) |
4627 | * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) * j.kh |
4628 | * j.kw * j.kd * j.ic_block * j.oc_block; |
4629 | float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb) |
4630 | * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b) * j.mb |
4631 | * j.oc_block * j.od * j.oh * j.tr_ow / j.nthr_mb_work; |
4632 | |
4633 | return src_v + dst_v + wei_v; |
4634 | }; |
4635 | |
4636 | float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); |
4637 | |
4638 | /* find the best thread distribution with lowest memory cost */ |
4639 | const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work); |
4640 | for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
4641 | const int nthr_par = nthr / nthr_mb; |
4642 | const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); |
4643 | for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
4644 | int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); |
4645 | |
4646 | float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
4647 | if (mem_cost <= best_mem_cost) { |
4648 | best_mem_cost = mem_cost; |
4649 | nthr_mb_ = nthr_mb; |
4650 | nthr_oc_b_ = nthr_oc_b; |
4651 | nthr_ic_b_ = nthr_ic_b; |
4652 | } |
4653 | } |
4654 | } |
4655 | |
4656 | if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr) |
4657 | nthr_mb_ = nstl::min(j.nthr_mb_work, nthr); |
4658 | nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; |
4659 | |
4660 | assert(nthr_ <= max_threads); |
4661 | } |
4662 | |
// Explicit template instantiations for each vector register width the
// forward and backward-data kernels are dispatched with.
template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>;
template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>;
template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>;
template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>;
template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>;
template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>;
4669 | } // namespace x64 |
4670 | } // namespace cpu |
4671 | } // namespace impl |
4672 | } // namespace dnnl |
4673 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
4674 | |