1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/cpu_convolution_pd.hpp" |
18 | |
19 | #include "cpu/x64/jit_uni_dw_conv_kernel_utils.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace x64 { |
25 | |
26 | using namespace data_type; |
27 | |
28 | template <cpu_isa_t isa, data_type_t kernel_dt> |
29 | status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf( |
30 | jit_conv_conf_t &jcp, const convolution_desc_t &cd, |
31 | memory_desc_t &src_md, memory_desc_t &weights_md, |
32 | memory_desc_t &bias_md, memory_desc_t &dst_md, primitive_attr_t &attr) { |
33 | |
34 | using namespace dnnl::impl::format_tag; |
35 | using namespace dnnl::impl::utils; |
36 | |
37 | const memory_desc_wrapper src_d(&src_md); |
38 | const memory_desc_wrapper weights_d(&weights_md); |
39 | const memory_desc_wrapper dst_d(&dst_md); |
40 | const memory_desc_wrapper bias_d(&bias_md); |
41 | |
42 | const int ndims = src_d.ndims(); |
43 | // Currently this kernel only supports 2D convolutions. |
44 | if (ndims != 4) return status::unimplemented; |
45 | |
46 | jcp.prop_kind = cd.prop_kind; |
47 | |
48 | const auto blocked_tag = isa == avx512_core ? nChw16c : nChw8c; |
49 | const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g; |
50 | const auto nxc_tag = nhwc; |
51 | const auto def_tag |
52 | = (mayiuse(avx512_core) |
53 | && jcp.prop_kind == prop_kind::forward_inference) |
54 | ? nxc_tag |
55 | : blocked_tag; |
56 | |
57 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
58 | |
59 | if (src_d.format_kind() == format_kind::any) { |
60 | CHECK(memory_desc_init_by_tag(src_md, def_tag)); |
61 | jcp.src_tag = def_tag; |
62 | } else { |
63 | jcp.src_tag = src_d.matches_one_of_tag(blocked_tag, nxc_tag); |
64 | } |
65 | |
66 | if (weights_d.format_kind() == format_kind::any) { |
67 | CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); |
68 | jcp.wei_tag = wei_tag; |
69 | } else { |
70 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
71 | } |
72 | |
73 | if (dst_d.format_kind() == format_kind::any) { |
74 | CHECK(memory_desc_init_by_tag(dst_md, def_tag)); |
75 | jcp.dst_tag = def_tag; |
76 | } else { |
77 | jcp.dst_tag = dst_d.matches_one_of_tag(blocked_tag, nxc_tag); |
78 | } |
79 | |
80 | if (jcp.with_bias) { |
81 | if (bias_d.format_kind() == format_kind::any) |
82 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
83 | } |
84 | |
85 | if (jcp.dst_tag != jcp.src_tag) return status::unimplemented; |
86 | const auto data_tag = jcp.src_tag; |
87 | const bool is_data_layout_nxc = data_tag == nxc_tag; |
88 | |
89 | const bool is_bf16 = src_d.data_type() == data_type::bf16; |
90 | |
91 | jcp.dst_dt = cd.dst_desc.data_type; |
92 | jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa; |
93 | |
94 | if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core))) |
95 | return status::unimplemented; |
96 | |
97 | const int simd_w = isa == avx512_core ? 16 : 8; |
98 | |
99 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
100 | if (!with_groups) return status::unimplemented; |
101 | |
102 | jcp.ngroups = weights_d.dims()[0]; |
103 | jcp.mb = src_d.dims()[0]; |
104 | |
105 | jcp.oc = dst_d.dims()[1]; |
106 | jcp.oc_without_padding = jcp.oc; |
107 | jcp.ic = src_d.dims()[1]; |
108 | |
109 | jcp.ih = src_d.dims()[2]; |
110 | jcp.iw = src_d.dims()[3]; |
111 | jcp.oh = dst_d.dims()[2]; |
112 | jcp.ow = dst_d.dims()[3]; |
113 | |
114 | jcp.kh = weights_d.dims()[3]; |
115 | jcp.kw = weights_d.dims()[4]; |
116 | |
117 | jcp.t_pad = cd.padding[0][0]; |
118 | jcp.l_pad = cd.padding[0][1]; |
119 | |
120 | jcp.stride_h = cd.strides[0]; |
121 | jcp.stride_w = cd.strides[1]; |
122 | |
123 | jcp.dilate_h = cd.dilates[0]; |
124 | jcp.dilate_w = cd.dilates[1]; |
125 | |
126 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
127 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
128 | jcp.r_pad = calculate_end_padding( |
129 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
130 | jcp.b_pad = calculate_end_padding( |
131 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
132 | bool kernel_outside_src = false || ext_kw <= jcp.l_pad |
133 | || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad |
134 | || ext_kh <= jcp.b_pad; |
135 | if (kernel_outside_src) return status::unimplemented; |
136 | |
137 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
138 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
139 | |
140 | jcp.loop_order = loop_ngcw; |
141 | |
142 | jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4) |
143 | : isa == avx512_core ? 6 : isa == avx2 ? 4 : 3; |
144 | jcp.ur_w = nstl::min(jcp.ur_w, jcp.ow); |
145 | |
146 | jcp.ch_block = simd_w; |
147 | jcp.nb_ch = div_up(jcp.oc, jcp.ch_block); |
148 | jcp.nb_ch_blocking = isa == avx512_core ? 4 : isa == avx2 ? 3 : 2; |
149 | if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch; |
150 | |
151 | if (is_data_layout_nxc) { |
152 | jcp.loop_order = loop_nhwcg; |
153 | const int resrc_depthwise_ur_w = (31 - jcp.kw + jcp.stride_w) |
154 | / (jcp.nb_ch_blocking + jcp.stride_w); |
155 | jcp.is_resrc_depthwise = (!is_bf16) && isa == avx512_core |
156 | && jcp.stride_w < jcp.kw && jcp.kw <= 5 && jcp.dilate_w == 0 |
157 | && resrc_depthwise_ur_w >= 2; |
158 | if (jcp.is_resrc_depthwise) { |
159 | jcp.ur_w = nstl::min(jcp.ow, resrc_depthwise_ur_w); |
160 | } |
161 | bool cache_aliasing |
162 | = (jcp.ngroups * jcp.iw * jcp.typesize_in) % 1024 == 0; |
163 | if (cache_aliasing) { |
164 | // currently only tuned for mobilenet-v1 shapes |
165 | const int limit = jcp.ow > 7 ? 7 : 4; |
166 | jcp.ur_w = nstl::min(jcp.ur_w, limit); |
167 | } |
168 | } else { |
169 | const size_t max_ch_off |
170 | = static_cast<size_t>(jcp.nb_ch_blocking - 1) * jcp.ch_block; |
171 | constexpr size_t max_ex_off |
172 | = isa == sse41 ? 4 : 0; // extra offset from repeats |
173 | |
174 | // check that input offsets fit into s32 |
175 | const size_t max_ic_off = max_ch_off * jcp.ih * jcp.iw; |
176 | const size_t max_iw_idx |
177 | = static_cast<size_t>(jcp.ur_w - 1) * jcp.stride_w |
178 | + (ext_kw - 1); |
179 | const size_t max_iw_off = max_iw_idx * jcp.ch_block; |
180 | const size_t max_input_offset |
181 | = (max_ic_off + max_iw_off + max_ex_off) * jcp.typesize_in; |
182 | if (max_input_offset > INT_MAX) return status::unimplemented; |
183 | |
184 | // check that output offsets fit into s32 |
185 | const size_t max_oc_off = max_ch_off * jcp.oh * jcp.ow; |
186 | const size_t max_ow_off |
187 | = static_cast<size_t>(jcp.ur_w - 1) * jcp.ch_block; |
188 | const size_t max_output_offset |
189 | = (max_oc_off + max_ow_off + max_ex_off) * jcp.typesize_out; |
190 | if (max_output_offset > INT_MAX) return status::unimplemented; |
191 | } |
192 | |
193 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
194 | |
195 | int r_pad_no_tail = nstl::max(0, |
196 | calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw, |
197 | jcp.stride_w, ext_kw)); |
198 | if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w) |
199 | return status::unimplemented; |
200 | |
201 | CHECK(attr.set_default_formats(&dst_md)); |
202 | |
203 | const auto &post_ops = attr.post_ops_; |
204 | |
205 | jcp.with_sum = post_ops.find(primitive_kind::sum) != -1; |
206 | const int eltwise_ind = post_ops.find(primitive_kind::eltwise); |
207 | jcp.with_eltwise = eltwise_ind != -1; |
208 | if (jcp.with_eltwise) jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise; |
209 | const int binary_ind = post_ops.find(primitive_kind::binary); |
210 | jcp.with_binary = binary_ind != -1; |
211 | if (jcp.with_binary) { |
212 | using namespace dnnl::impl::cpu::binary_injector_utils; |
213 | std::tie(jcp.with_binary_per_oc_bcast, jcp.with_binary_no_bcast) |
214 | = bcast_strategies_present_tup(post_ops.entry_, dst_d, |
215 | broadcasting_strategy_t::per_oc, |
216 | broadcasting_strategy_t::no_broadcast); |
217 | } |
218 | |
219 | jcp.post_ops = post_ops; |
220 | |
221 | using namespace injector; |
222 | static constexpr bool sum_at_pos_0_only = true; |
223 | static constexpr bool sum_requires_scale_one = true; |
224 | const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, |
225 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one}); |
226 | if (!post_ops_ok_) return status::unimplemented; |
227 | |
228 | const bool ok_to_pad_channels = true && !is_data_layout_nxc |
229 | && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups |
230 | && one_of(isa, avx512_core, avx2); |
231 | if (ok_to_pad_channels) { |
232 | jcp.oc = rnd_up(jcp.oc, simd_w); |
233 | jcp.ic = rnd_up(jcp.oc, simd_w); |
234 | jcp.ngroups = rnd_up(jcp.ngroups, simd_w); |
235 | } |
236 | |
237 | const bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups |
238 | && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % simd_w == 0) |
239 | && jcp.wei_tag == wei_tag && data_tag != format_tag::undef |
240 | && jcp.ic <= src_d.padded_dims()[1] |
241 | && jcp.oc <= dst_d.padded_dims()[1] |
242 | && jcp.ngroups <= weights_d.padded_dims()[0]; |
243 | if (!args_ok) return status::unimplemented; |
244 | |
245 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
246 | |
247 | return status::success; |
248 | } |
249 | |
250 | template <cpu_isa_t isa, data_type_t kernel_dt> |
251 | void jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_scratchpad( |
252 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
253 | using namespace dnnl::impl::memory_tracking::names; |
254 | if (jcp.bia_dt == data_type::bf16) |
255 | scratchpad.book<float>(key_conv_bias_bf16_convert_wsp, jcp.oc); |
256 | else if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) |
257 | scratchpad.book<float>(key_conv_padded_bias, jcp.oc); |
258 | } |
259 | |
260 | template <cpu_isa_t isa, data_type_t kernel_dt> |
261 | status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf( |
262 | jit_conv_conf_t &jcp, const convolution_desc_t &cd, |
263 | memory_desc_t &diff_src_md, memory_desc_t &weights_md, |
264 | memory_desc_t &diff_dst_md) { |
265 | using namespace dnnl::impl::format_tag; |
266 | using namespace dnnl::impl::utils; |
267 | |
268 | const memory_desc_wrapper diff_src_d(&diff_src_md); |
269 | const memory_desc_wrapper weights_d(&weights_md); |
270 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
271 | |
272 | jcp.dsrc_dt = cd.diff_src_desc.data_type; |
273 | const bool is_bf16 = diff_dst_d.data_type() == bf16; |
274 | jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa; |
275 | |
276 | if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core))) |
277 | return status::unimplemented; |
278 | |
279 | const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; |
280 | if (!with_groups) return status::unimplemented; |
281 | |
282 | const int ndims = diff_src_d.ndims(); |
283 | jcp.ngroups = weights_d.dims()[0]; |
284 | jcp.mb = diff_src_d.dims()[0]; |
285 | |
286 | jcp.oc = diff_dst_d.dims()[1]; |
287 | jcp.oc_without_padding = jcp.oc; |
288 | jcp.ic = diff_src_d.dims()[1]; |
289 | |
290 | jcp.ih = diff_src_d.dims()[2]; |
291 | jcp.iw = diff_src_d.dims()[3]; |
292 | jcp.oh = diff_dst_d.dims()[2]; |
293 | jcp.ow = diff_dst_d.dims()[3]; |
294 | |
295 | jcp.kh = weights_d.dims()[3]; |
296 | jcp.kw = weights_d.dims()[4]; |
297 | |
298 | jcp.t_pad = cd.padding[0][0]; |
299 | jcp.l_pad = cd.padding[0][1]; |
300 | |
301 | jcp.stride_h = cd.strides[0]; |
302 | jcp.stride_w = cd.strides[1]; |
303 | |
304 | jcp.dilate_h = cd.dilates[0]; |
305 | jcp.dilate_w = cd.dilates[1]; |
306 | |
307 | const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
308 | const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
309 | jcp.r_pad = calculate_end_padding( |
310 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
311 | jcp.b_pad = calculate_end_padding( |
312 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
313 | |
314 | jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; |
315 | jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; |
316 | |
317 | const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); |
318 | const auto dat_tag_blocked = isa == avx512_core ? nChw16c : nChw8c; |
319 | const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g; |
320 | |
321 | auto curr_src_tag |
322 | = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); |
323 | auto curr_dst_tag |
324 | = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked); |
325 | bool is_data_layout_nxc |
326 | = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
327 | auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked; |
328 | |
329 | if (diff_src_md.format_kind == format_kind::any) { |
330 | CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_blocked)); |
331 | jcp.src_tag = dat_tag_blocked; |
332 | } else if (curr_src_tag != dat_tag) |
333 | return status::unimplemented; |
334 | else |
335 | jcp.src_tag = dat_tag; |
336 | |
337 | if (diff_dst_md.format_kind == format_kind::any) { |
338 | CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag_blocked)); |
339 | jcp.dst_tag = dat_tag_blocked; |
340 | } else if (curr_dst_tag != dat_tag) |
341 | return status::unimplemented; |
342 | else |
343 | jcp.dst_tag = dat_tag; |
344 | |
345 | if (weights_d.format_kind() == format_kind::any) { |
346 | CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); |
347 | jcp.wei_tag = wei_tag; |
348 | } else { |
349 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
350 | } |
351 | |
352 | // No support for mixed types between SRC and DIFF_DST tensors |
353 | if (!everyone_is(dat_tag, jcp.src_tag, jcp.dst_tag) |
354 | || jcp.wei_tag != wei_tag) |
355 | return status::unimplemented; |
356 | |
357 | // note: sse41 uses 'ch_block = 8' where the value is derived |
358 | // from: 'simd_w_ * reg_repeats_ = 4 * 2' |
359 | jcp.ch_block = isa == avx512_core ? 16 : 8; |
360 | |
361 | bool ok_to_pad_channels = !is_data_layout_nxc && jcp.oc == jcp.ngroups |
362 | && jcp.ic == jcp.ngroups && one_of(isa, avx512_core, avx2); |
363 | if (ok_to_pad_channels) { |
364 | jcp.oc = rnd_up(jcp.oc, jcp.ch_block); |
365 | jcp.ic = rnd_up(jcp.oc, jcp.ch_block); |
366 | jcp.ngroups = rnd_up(jcp.ngroups, jcp.ch_block); |
367 | } |
368 | |
369 | bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups |
370 | && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % jcp.ch_block == 0) |
371 | && jcp.dilate_h == 0 && jcp.dilate_w == 0 |
372 | && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 |
373 | && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 |
374 | && jcp.ic <= diff_src_d.padded_dims()[1] |
375 | && jcp.oc <= diff_dst_d.padded_dims()[1] |
376 | && jcp.ngroups <= weights_d.padded_dims()[0]; |
377 | if (!args_ok) return status::unimplemented; |
378 | |
379 | jcp.typesize_out = types::data_type_size(diff_src_d.data_type()); |
380 | jcp.typesize_in = types::data_type_size(diff_dst_d.data_type()); |
381 | |
382 | jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4) |
383 | : isa == avx512_core ? 6 : isa == avx2 ? 4 : 3; |
384 | |
385 | jcp.loop_order = is_data_layout_nxc ? loop_nhwcg : loop_ngcw; |
386 | |
387 | jcp.ch_tail = jcp.ngroups % jcp.ch_block; |
388 | jcp.nb_ch = div_up(jcp.ic, jcp.ch_block); |
389 | jcp.nb_ch_blocking = isa == avx512_core ? 4 : isa == avx2 ? 3 : 2; |
390 | if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch; |
391 | |
392 | const size_t max_ch_off |
393 | = static_cast<size_t>(jcp.nb_ch_blocking - 1) * jcp.ch_block; |
394 | constexpr size_t max_ex_off |
395 | = isa == sse41 ? 4 : 0; // extra offset from repeats |
396 | const size_t sp_step = is_data_layout_nxc ? jcp.ngroups : jcp.ch_block; |
397 | |
398 | // check that input offsets fit into s32 |
399 | const size_t max_oc_off |
400 | = max_ch_off * (is_data_layout_nxc ? 1 : jcp.oh * jcp.ow); |
401 | const size_t max_inp_sp_off = static_cast<size_t>(jcp.ur_w - 1) * sp_step; |
402 | const size_t max_input_offset |
403 | = (max_oc_off + max_inp_sp_off + max_ex_off) * jcp.typesize_in; |
404 | if (max_input_offset > INT_MAX) return status::unimplemented; |
405 | |
406 | // check that output offset fit into s32 |
407 | const size_t max_ic_off |
408 | = max_ch_off * (is_data_layout_nxc ? 1 : jcp.ih * jcp.iw); |
409 | const size_t max_out_sp_off |
410 | = static_cast<size_t>(jcp.ur_w - 1) * jcp.stride_w * sp_step; |
411 | const size_t max_output_offset |
412 | = (max_ic_off + max_out_sp_off + max_ex_off) * jcp.typesize_out; |
413 | if (max_output_offset > INT_MAX) return status::unimplemented; |
414 | |
415 | return status::success; |
416 | } |
417 | |
418 | template <cpu_isa_t isa, data_type_t kernel_dt> |
419 | void jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad( |
420 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
421 | UNUSED(scratchpad); |
422 | UNUSED(jcp); |
423 | } |
424 | |
/* Initializes the jit configuration (jcp) for the backward-by-weights
 * depthwise convolution kernel: resolves format_kind::any tags in-place,
 * validates shape/layout applicability, and partitions the thread work via
 * balance(). Returns status::unimplemented for unsupported cases. */
template <cpu_isa_t isa, data_type_t kernel_dt>
status_t jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    using namespace dnnl::impl::format_tag;
    using namespace dnnl::impl::utils;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    jcp.dwei_dt = cd.diff_weights_desc.data_type;
    const int ndims = src_d.ndims();
    const bool is_bf16 = src_d.data_type() == data_type::bf16;
    // Upgrade the ISA for bf16 when the CPU supports native bf16 ops.
    jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa;

    if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
        return status::unimplemented;

    // Per-group channel counts; a depthwise problem must have oc == ic == 1
    // per group (checked via is_depthwise below).
    jcp.ngroups = diff_weights_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = diff_dst_d.dims()[1];
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);

    if (!jcp.is_depthwise) return status::unimplemented;

    jcp.mb = src_d.dims()[0];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];

    jcp.kh = diff_weights_d.dims()[3];
    jcp.kw = diff_weights_d.dims()[4];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));

    // Padded input extents used by the shape-consistency checks below.
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_blocked
            = isa == avx512_core ? nChw16c : nChw8c; // dnnl_aBcd16b
    const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g;
    auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    auto curr_dst_tag
            = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);

    // Only use nxc when both tensors already come in that layout; resolving
    // 'any' always selects the blocked layout.
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);

    auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked;

    if (src_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_blocked));
        jcp.src_tag = dat_tag_blocked;
    } else if (curr_src_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.src_tag = dat_tag;

    if (diff_dst_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag_blocked));
        jcp.dst_tag = dat_tag_blocked;
    } else if (curr_dst_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.dst_tag = dat_tag;

    if (diff_weights_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
    }

    // No support for mixed types between SRC and DIFF_DST tensors
    if (!everyone_is(dat_tag, jcp.src_tag, jcp.dst_tag)
            || jcp.wei_tag != wei_tag)
        return status::unimplemented;

    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, x));
    }

    jcp.ch_block = isa == avx512_core ? 16 : 8;
    // Remainder channels when ngroups is not a multiple of ch_block.
    jcp.ch_tail = jcp.oc_without_padding % jcp.ch_block;

    // note: bf16 to be supported in the next commit
    bool ok_to_pad_channels
            = !is_data_layout_nxc && one_of(isa, avx512_core, avx2);
    if (ok_to_pad_channels) { jcp.ngroups = rnd_up(jcp.ngroups, jcp.ch_block); }

    bool args_ok = true
            && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % jcp.ch_block == 0)
            && jcp.dilate_h == 0 && jcp.dilate_w == 0 && jcp.kw <= 3
            && jcp.stride_w <= jcp.kw // no gaps in kernel
            && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
            && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
    if (!args_ok) return status::unimplemented;

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);

    // Note: avx2 can't do masked_fma and would require extra Vmms
    // for byte_load.
    // TODO: enable 'is_fast_depthwise' for bf16 if it offers performance
    // improvement.
    jcp.is_fast_depthwise
            = !is_bf16 && is_data_layout_nxc && one_of(isa, avx512_core, avx2);
    constexpr int max_reg_idx = isa == avx512_core ? 31 : 15;
    // Note: anything larger than 4 didn't show significant speedup
    const int max_isa_unroll = jcp.is_fast_depthwise ? 4 : 1;
    // Each unrolled channel block consumes 2 registers per kw position.
    int max_ch_unroll = nstl::min(max_isa_unroll, max_reg_idx / (2 * jcp.kw));
    jcp.nb_ch_blocking = nstl::min(jcp.nb_ch, max_ch_unroll);

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    // Note: '(k - 1 + 1) / 2' is just 'k / 2' written to mirror the
    // "half filter" derivation; kept as-is.
    const int max_hpad = (jcp.kh - 1 + 1) / 2;
    const int max_wpad = (jcp.kw - 1 + 1) / 2;
    const int min_ih = jcp.kh + nstl::modulo(-jcp.t_pad, jcp.stride_h);
    const bool boundaries_ok = true && jcp.t_pad <= max_hpad
            && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
            && jcp.r_pad <= max_wpad
            // input must fully accommodate the filter
            && jcp.ih >= min_ih
            // non-unit padding must be a multiple of the stride
            && IMPLICATION(jcp.t_pad > 1, jcp.t_pad % jcp.stride_h == 0)
            && IMPLICATION(jcp.b_pad > 1, jcp.b_pad % jcp.stride_h == 0);
    if (!boundaries_ok) return status::unimplemented;

    /* BF16: accumulation of output happens in f32, down-conversion to bf16
     * happens during the reduction phase. */
    jcp.typesize_out = sizeof(float);
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.bia_dt = jcp.with_bias ? cd.diff_bias_desc.data_type : data_type::undef;

    jcp.harness = is_data_layout_nxc ? harness_nxc : harness_mb_reduction;

    // Partition the work across threads (fills jcp.nthr_*).
    balance(jcp, nthreads);

    return status::success;
}
594 | |
595 | template <cpu_isa_t isa, data_type_t kernel_dt> |
596 | void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_scratchpad( |
597 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
598 | using namespace dnnl::impl::memory_tracking::names; |
599 | |
600 | if (jcp.harness == harness_mb_reduction) { |
601 | /* Notes: if splitting thread work on 'mb', then a reduction has to take |
602 | * place. Hence, book a per-thread, local weights-buffer for the |
603 | * reduction */ |
604 | if (jcp.nthr_mb > 1) { |
605 | const size_t mb = jcp.dwei_dt == data_type::bf16 ? jcp.nthr_mb |
606 | : jcp.nthr_mb - 1; |
607 | const size_t wei_size |
608 | = static_cast<size_t>(jcp.ngroups) * jcp.kh * jcp.kw; |
609 | scratchpad.book<float>(key_conv_wei_reduction, wei_size * mb); |
610 | |
611 | if (jcp.with_bias) |
612 | scratchpad.book<float>(key_conv_bia_reduction, |
613 | static_cast<size_t>(jcp.ngroups) * (jcp.nthr_mb - 1)); |
614 | } else if (jcp.nthr_mb == 1 && jcp.dwei_dt == data_type::bf16) { |
615 | const size_t wei_size |
616 | = static_cast<size_t>(jcp.ngroups) * jcp.kh * jcp.kw; |
617 | scratchpad.book<float>(key_conv_wei_reduction, wei_size); |
618 | } |
619 | } else if (jcp.harness == harness_nxc) { |
620 | if (jcp.nthr > 1 || jcp.dwei_dt == data_type::bf16) { |
621 | assert(jcp.nthr > 0); // redundant check |
622 | const size_t buff_count |
623 | = jcp.dwei_dt == data_type::bf16 ? jcp.nthr : jcp.nthr - 1; |
624 | |
625 | // note: because of weights blocked format, buffer is padded |
626 | // across ch_block |
627 | const size_t wei_size = static_cast<size_t>(utils::rnd_up( |
628 | jcp.ngroups, jcp.ch_block)) |
629 | * jcp.kh * jcp.kw; |
630 | scratchpad.book<float>( |
631 | key_conv_wei_reduction, wei_size * buff_count); |
632 | |
633 | if (jcp.with_bias) { |
634 | scratchpad.book<float>( |
635 | key_conv_bia_reduction, jcp.ngroups * buff_count); |
636 | } |
637 | } |
638 | } |
639 | |
640 | if (jcp.bia_dt == data_type::bf16) |
641 | scratchpad.book<float>(key_conv_bias_bf16_convert_wsp, jcp.ngroups); |
642 | } |
643 | |
644 | template <cpu_isa_t isa, data_type_t kernel_dt> |
645 | void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::balance( |
646 | jit_conv_conf_t &jcp, int nthreads) { |
647 | jcp.nthr_oh = jcp.nthr_g = jcp.nthr_mb = 1; |
648 | if (jcp.harness == harness_mb_reduction) { |
649 | /* Basic-Heuristics for parallel strategy: |
650 | * 1) Tries to parallel on the number of Groups (g) where tasks are |
651 | * independent. Otherwise, |
652 | * 2) Tries to split the work across g and MiniBatch (mb). |
653 | * Parallelizing on mb requires computing a reduction for weights. |
654 | * |
655 | * NOTE: because of 'task partitioning' scheme, there will be unbalanced |
656 | * per-thread load when the number of threads is high (e.g. > 16). |
657 | */ |
658 | jcp.oh_blk_size = 15; |
659 | jcp.nthr_g = nstl::min(jcp.nb_ch, nthreads); |
660 | jcp.nthr_mb = nstl::min(nstl::max(1, nthreads / jcp.nthr_g), jcp.mb); |
661 | jcp.nthr = jcp.nthr_g * jcp.nthr_mb; |
662 | } else if (jcp.harness == harness_nxc) { |
663 | /* Allocate threads and partition space with regards to 'nb_ch', 'mb' |
664 | * and 'nb_oh' (derived from selecting 'oh_block') |
665 | * |
666 | * note: 'prioritize_threading == true' showed slightly greater |
667 | * performance, but there might be cases where the opposite holds true; |
668 | * code is left for future tuning. */ |
669 | partition_nthr_nxc(jcp, nthreads, true); |
670 | jcp.nthr = jcp.nthr_g * jcp.nthr_mb * jcp.nthr_oh; |
671 | } |
672 | } |
673 | |
template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::partition_nthr_nxc(
        jit_conv_conf_t &jcp, int nthreads, bool prioritize_threading) {

    /* Explore thread partitioning space across 'nb_ch', 'mb' and 'nb_oh'
     * (determined by 'oh / oh_block'). Prioritize finding a
     * partition where the most number of threads are used ('thr_eff').
     *
     * Additionally, try to reduce work imbalance across threads
     * (i.e. 'total_imbalance').
     *
     * Writes the winning partition into jcp.nthr_g / nthr_mb / nthr_oh and
     * jcp.oh_blk_size. jcp.nthr itself is set by the caller (balance()).
     */
    float best_thr_eff = 0.; // maximinze
    float best_imbalance = 1.; // minimize

    // Performance-tuning variables - enable through 'getenv_int()'
    // if necessary
    const int env_max_nthr_g = nthreads; // DNNL_MAX_NTHR_G
    const int env_max_nthr_mb = nthreads; // DNNL_MAX_NTHR_MB
    const int env_max_nthr_oh = nthreads; // DNNL_MAX_NTHR_OH
    const int env_min_oh_block = 1; // DNNL_MIN_OH_BLOCK

    // Outermost loop: candidate thread counts along the channel-block dim.
    const int ch_outer_blocks = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
    int max_g = nstl::min(env_max_nthr_g, nstl::min(ch_outer_blocks, nthreads));
    for (int g = max_g; g >= 1; --g) {
        int cur_nthr_g = g;
        auto div_nthr_g = nthreads / cur_nthr_g;

        // Middle loop: candidate thread counts along minibatch, limited to
        // the threads left over after the group split.
        int available_nthr_mb = div_nthr_g;
        int max_mb = nstl::min(
                env_max_nthr_mb, nstl::min(jcp.mb, available_nthr_mb));
        for (int mb = max_mb; mb >= 1; --mb) {
            int cur_nthr_mb = mb;
            auto div_nthr_mb = available_nthr_mb / cur_nthr_mb;

            // used to skip cases where efficiency can only worsen
            bool prev_under_blocked = false;

            // Innermost loop: candidate 'oh' block sizes, from one block
            // covering all of 'oh' down to env_min_oh_block.
            int available_nthr_oh = nstl::min(
                    jcp.oh, nstl::min(env_max_nthr_oh, div_nthr_mb));
            int max_oh_block = jcp.oh;
            // Note: maybe it's worth exploring a heuristic to determine
            // optimal_min(oh_block)
            int min_oh_block
                    = nstl::max(1, nstl::min(jcp.oh, env_min_oh_block));
            for (int oh_block = max_oh_block; oh_block >= min_oh_block;
                    --oh_block) {

                // Calculate most efficient approximation for thread use and/or
                // blocking:
                int approx_g_block = utils::div_up(ch_outer_blocks, cur_nthr_g);
                int approx_mb_block = utils::div_up(jcp.mb, cur_nthr_mb);
                int approx_oh_block = utils::div_up(jcp.oh, oh_block);

                int cur_nthr_oh = nstl::min(available_nthr_oh, approx_oh_block);

                // calculate thread use efficiency
                int total_nthr = cur_nthr_g * cur_nthr_mb * cur_nthr_oh;
                float thr_eff = ((float)total_nthr) / nthreads;
                assert(total_nthr <= nthreads);

                // efficiency can only worsen, skip
                if (prev_under_blocked && available_nthr_oh < approx_oh_block) {
                    break;
                }

                // calculate imbalance: leftover work per dimension relative
                // to that dimension's total, weighted by the size of the
                // other two dimensions.
                float imbalance_g = ((float)std::abs(approx_g_block * cur_nthr_g
                                            - ch_outer_blocks))
                        / ch_outer_blocks;
                float imbalance_mb
                        = ((float)std::abs(
                                  approx_mb_block * cur_nthr_mb - jcp.mb))
                        / jcp.mb;
                float imbalance_oh
                        = ((float)std::abs(oh_block * cur_nthr_oh - jcp.oh))
                        / jcp.oh;
                float total_imbalance = imbalance_g * (jcp.mb * jcp.oh)
                        + imbalance_mb * (ch_outer_blocks * jcp.oh)
                        + imbalance_oh * (ch_outer_blocks * jcp.mb);

                /* 1) When 'prioritize_threading == true'
                 * First Condition: pick the blocking strategy that uses the
                 * most threads.
                 * Second Condition: if current blocking strategy uses at least
                 * the same amount of threads than the previous best (or more),
                 * chose if work imbalance is less than previous best.
                 *
                 * 2) Otherwise, ('prioritize_threading == false')
                 * First Condition: pick the blocking strategy that has the
                 * lowest thread work imbalance.
                 * Second Condition: if current blocking strategy has at least
                 * the same amount of work imbalance than the previous best(or
                 * lower), chose if it has more number of threads working.
                 * */
                const bool first_condition = prioritize_threading
                        ? best_thr_eff <= thr_eff
                        : best_imbalance >= total_imbalance;
                const bool second_condition = prioritize_threading
                        ? best_thr_eff == thr_eff
                                && best_imbalance <= total_imbalance
                        : best_imbalance == total_imbalance
                                && best_thr_eff >= thr_eff;
                // A candidate that ties the primary metric without improving
                // the secondary one is skipped (continue); otherwise it
                // becomes the new best.
                if (first_condition) {
                    if (second_condition) { continue; }
                    jcp.nthr_g = cur_nthr_g;
                    jcp.nthr_mb = cur_nthr_mb;
                    jcp.nthr_oh = cur_nthr_oh;
                    jcp.oh_blk_size = oh_block;
                    best_imbalance = total_imbalance;
                    best_thr_eff = thr_eff;
                }
                prev_under_blocked = oh_block * cur_nthr_oh < jcp.oh;
            }
        }
    }
}
790 | |
// Explicit template instantiations for each supported (ISA, data type)
// pair. The REG_*_ISA wrappers compile an instantiation out when support
// for the corresponding ISA is disabled in the build.
REG_AVX512_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx512_core, bf16>);
REG_AVX512_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_fwd_kernel<sse41, f32>);

REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, bf16>);
REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_bwd_data_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_bwd_data_kernel<sse41, f32>);

REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core, bf16>);
REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_bwd_weights_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_bwd_weights_kernel<sse41, f32>);
809 | } // namespace x64 |
810 | } // namespace cpu |
811 | } // namespace impl |
812 | } // namespace dnnl |
813 | |