/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/cpu_convolution_pd.hpp"

#include "cpu/x64/jit_uni_dw_conv_kernel_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace data_type;

template <cpu_isa_t isa, data_type_t kernel_dt>
status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &weights_md,
        memory_desc_t &bias_md, memory_desc_t &dst_md, primitive_attr_t &attr) {

    using namespace dnnl::impl::format_tag;
    using namespace dnnl::impl::utils;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const int ndims = src_d.ndims();
    // Currently this kernel only supports 2D convolutions.
    if (ndims != 4) return status::unimplemented;

    jcp.prop_kind = cd.prop_kind;

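    // Default memory layouts used when the user leaves the format as 'any':
    // blocked nChw16c / nChw8c in general, channels-last (nhwc) for forward
    // inference on avx512_core.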
    const auto blocked_tag = isa == avx512_core ? nChw16c : nChw8c;
    const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g;
    const auto nxc_tag = nhwc;
    const auto def_tag
            = (mayiuse(avx512_core)
                      && jcp.prop_kind == prop_kind::forward_inference)
            ? nxc_tag
            : blocked_tag;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, def_tag));
        jcp.src_tag = def_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(blocked_tag, nxc_tag);
    }

    if (weights_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
    }

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, def_tag));
        jcp.dst_tag = def_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(blocked_tag, nxc_tag);
    }

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
    const auto data_tag = jcp.src_tag;
    const bool is_data_layout_nxc = data_tag == nxc_tag;

    const bool is_bf16 = src_d.data_type() == data_type::bf16;

    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa;

    if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
        return status::unimplemented;

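    // simd_w doubles as the channel block: 16 lanes on avx512_core, 8
    // otherwise (on sse41 the 8 comes from two 4-wide register repeats).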
    const int simd_w = isa == avx512_core ? 16 : 8;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;

    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1];
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad
            || ext_kh <= jcp.b_pad;
    if (kernel_outside_src) return status::unimplemented;

    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_in = types::data_type_size(src_d.data_type());

    jcp.loop_order = loop_ngcw;

    jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4)
                       : isa == avx512_core ? 6 : isa == avx2 ? 4 : 3;
    jcp.ur_w = nstl::min(jcp.ur_w, jcp.ow);

    jcp.ch_block = simd_w;
    jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);
    jcp.nb_ch_blocking = isa == avx512_core ? 4 : isa == avx2 ? 3 : 2;
    if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch;

    if (is_data_layout_nxc) {
        jcp.loop_order = loop_nhwcg;
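        // Register-budget estimate for the "re-use source" kernel: roughly
        // nb_ch_blocking * ur_w accumulators plus (ur_w - 1) * stride_w + kw
        // cached input registers must fit in the zmm file, which bounds
        // ur_w by (31 - kw + stride_w) / (nb_ch_blocking + stride_w).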
        const int resrc_depthwise_ur_w = (31 - jcp.kw + jcp.stride_w)
                / (jcp.nb_ch_blocking + jcp.stride_w);
        jcp.is_resrc_depthwise = (!is_bf16) && isa == avx512_core
                && jcp.stride_w < jcp.kw && jcp.kw <= 5 && jcp.dilate_w == 0
                && resrc_depthwise_ur_w >= 2;
        if (jcp.is_resrc_depthwise) {
            jcp.ur_w = nstl::min(jcp.ow, resrc_depthwise_ur_w);
        }
        bool cache_aliasing
                = (jcp.ngroups * jcp.iw * jcp.typesize_in) % 1024 == 0;
        if (cache_aliasing) {
            // currently only tuned for mobilenet-v1 shapes
            const int limit = jcp.ow > 7 ? 7 : 4;
            jcp.ur_w = nstl::min(jcp.ur_w, limit);
        }
    } else {
        const size_t max_ch_off
                = static_cast<size_t>(jcp.nb_ch_blocking - 1) * jcp.ch_block;
        constexpr size_t max_ex_off
                = isa == sse41 ? 4 : 0; // extra offset from repeats

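        // Compile-time offsets are emitted as 32-bit displacements in the
        // generated code, so the largest offsets must fit into a signed
        // 32-bit integer.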
        // check that input offsets fit into s32
        const size_t max_ic_off = max_ch_off * jcp.ih * jcp.iw;
        const size_t max_iw_idx
                = static_cast<size_t>(jcp.ur_w - 1) * jcp.stride_w
                + (ext_kw - 1);
        const size_t max_iw_off = max_iw_idx * jcp.ch_block;
        const size_t max_input_offset
                = (max_ic_off + max_iw_off + max_ex_off) * jcp.typesize_in;
        if (max_input_offset > INT_MAX) return status::unimplemented;

        // check that output offsets fit into s32
        const size_t max_oc_off = max_ch_off * jcp.oh * jcp.ow;
        const size_t max_ow_off
                = static_cast<size_t>(jcp.ur_w - 1) * jcp.ch_block;
        const size_t max_output_offset
                = (max_oc_off + max_ow_off + max_ex_off) * jcp.typesize_out;
        if (max_output_offset > INT_MAX) return status::unimplemented;
    }

    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));
    if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    CHECK(attr.set_default_formats(&dst_md));

    const auto &post_ops = attr.post_ops_;

    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise;
    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;
    if (jcp.with_binary) {
        using namespace dnnl::impl::cpu::binary_injector_utils;
        std::tie(jcp.with_binary_per_oc_bcast, jcp.with_binary_no_bcast)
                = bcast_strategies_present_tup(post_ops.entry_, dst_d,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::no_broadcast);
    }

    jcp.post_ops = post_ops;

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one});
    if (!post_ops_ok_) return status::unimplemented;

    const bool ok_to_pad_channels = true && !is_data_layout_nxc
            && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups
            && one_of(isa, avx512_core, avx2);
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
        jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
    }

    const bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups
            && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % simd_w == 0)
            && jcp.wei_tag == wei_tag && data_tag != format_tag::undef
            && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && jcp.ngroups <= weights_d.padded_dims()[0];
    if (!args_ok) return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;

    return status::success;
}

template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;
    if (jcp.bia_dt == data_type::bf16)
        scratchpad.book<float>(key_conv_bias_bf16_convert_wsp, jcp.oc);
    else if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
        scratchpad.book<float>(key_conv_padded_bias, jcp.oc);
}

template <cpu_isa_t isa, data_type_t kernel_dt>
status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &diff_src_md, memory_desc_t &weights_md,
        memory_desc_t &diff_dst_md) {
    using namespace dnnl::impl::format_tag;
    using namespace dnnl::impl::utils;

    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    jcp.dsrc_dt = cd.diff_src_desc.data_type;
    const bool is_bf16 = diff_dst_d.data_type() == bf16;
    jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa;

    if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
        return status::unimplemented;

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;

    const int ndims = diff_src_d.ndims();
    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = diff_src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1];
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1];

    jcp.ih = diff_src_d.dims()[2];
    jcp.iw = diff_src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];

    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_blocked = isa == avx512_core ? nChw16c : nChw8c;
    const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g;

    auto curr_src_tag
            = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    auto curr_dst_tag
            = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
    auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked;

    if (diff_src_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_blocked));
        jcp.src_tag = dat_tag_blocked;
    } else if (curr_src_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.src_tag = dat_tag;

    if (diff_dst_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag_blocked));
        jcp.dst_tag = dat_tag_blocked;
    } else if (curr_dst_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.dst_tag = dat_tag;

    if (weights_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
    }

    // No support for mixed types between SRC and DIFF_DST tensors
    if (!everyone_is(dat_tag, jcp.src_tag, jcp.dst_tag)
            || jcp.wei_tag != wei_tag)
        return status::unimplemented;

    // note: sse41 uses 'ch_block = 8' where the value is derived
    // from: 'simd_w_ * reg_repeats_ = 4 * 2'
    jcp.ch_block = isa == avx512_core ? 16 : 8;

    bool ok_to_pad_channels = !is_data_layout_nxc && jcp.oc == jcp.ngroups
            && jcp.ic == jcp.ngroups && one_of(isa, avx512_core, avx2);
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, jcp.ch_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ch_block);
        jcp.ngroups = rnd_up(jcp.ngroups, jcp.ch_block);
    }

    bool args_ok = true && jcp.oc == jcp.ngroups && jcp.ic == jcp.ngroups
            && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % jcp.ch_block == 0)
            && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
            && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
            && jcp.ic <= diff_src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ngroups <= weights_d.padded_dims()[0];
    if (!args_ok) return status::unimplemented;

    jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());

    jcp.ur_w = is_bf16 ? (isa_has_bf16(jcp.isa) ? 6 : 4)
                       : isa == avx512_core ? 6 : isa == avx2 ? 4 : 3;

    jcp.loop_order = is_data_layout_nxc ? loop_nhwcg : loop_ngcw;

    jcp.ch_tail = jcp.ngroups % jcp.ch_block;
    jcp.nb_ch = div_up(jcp.ic, jcp.ch_block);
    jcp.nb_ch_blocking = isa == avx512_core ? 4 : isa == avx2 ? 3 : 2;
    if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch;

    const size_t max_ch_off
            = static_cast<size_t>(jcp.nb_ch_blocking - 1) * jcp.ch_block;
    constexpr size_t max_ex_off
            = isa == sse41 ? 4 : 0; // extra offset from repeats
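    // Step between neighbouring spatial pixels, in elements: with the nxc
    // (channels-last) layout channels are innermost, so the step is ngroups;
    // with blocked layouts it is one channel block.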
    const size_t sp_step = is_data_layout_nxc ? jcp.ngroups : jcp.ch_block;

    // check that input offsets fit into s32
    const size_t max_oc_off
            = max_ch_off * (is_data_layout_nxc ? 1 : jcp.oh * jcp.ow);
    const size_t max_inp_sp_off = static_cast<size_t>(jcp.ur_w - 1) * sp_step;
    const size_t max_input_offset
            = (max_oc_off + max_inp_sp_off + max_ex_off) * jcp.typesize_in;
    if (max_input_offset > INT_MAX) return status::unimplemented;

    // check that output offsets fit into s32
    const size_t max_ic_off
            = max_ch_off * (is_data_layout_nxc ? 1 : jcp.ih * jcp.iw);
    const size_t max_out_sp_off
            = static_cast<size_t>(jcp.ur_w - 1) * jcp.stride_w * sp_step;
    const size_t max_output_offset
            = (max_ic_off + max_out_sp_off + max_ex_off) * jcp.typesize_out;
    if (max_output_offset > INT_MAX) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    UNUSED(scratchpad);
    UNUSED(jcp);
}

template <cpu_isa_t isa, data_type_t kernel_dt>
status_t jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    using namespace dnnl::impl::format_tag;
    using namespace dnnl::impl::utils;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    jcp.dwei_dt = cd.diff_weights_desc.data_type;
    const int ndims = src_d.ndims();
    const bool is_bf16 = src_d.data_type() == data_type::bf16;
    jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa;

    if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
        return status::unimplemented;

    jcp.ngroups = diff_weights_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = diff_dst_d.dims()[1];
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);

    if (!jcp.is_depthwise) return status::unimplemented;

    jcp.mb = src_d.dims()[0];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];

    jcp.kh = diff_weights_d.dims()[3];
    jcp.kw = diff_weights_d.dims()[4];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_blocked
            = isa == avx512_core ? nChw16c : nChw8c; // dnnl_aBcd16b
    const auto wei_tag = isa == avx512_core ? Goihw16g : Goihw8g;
    auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    auto curr_dst_tag
            = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);

    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);

    auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked;

    if (src_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_blocked));
        jcp.src_tag = dat_tag_blocked;
    } else if (curr_src_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.src_tag = dat_tag;

    if (diff_dst_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag_blocked));
        jcp.dst_tag = dat_tag_blocked;
    } else if (curr_dst_tag != dat_tag)
        return status::unimplemented;
    else
        jcp.dst_tag = dat_tag;

    if (diff_weights_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
    }

    // No support for mixed types between SRC and DIFF_DST tensors
    if (!everyone_is(dat_tag, jcp.src_tag, jcp.dst_tag)
            || jcp.wei_tag != wei_tag)
        return status::unimplemented;

    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, x));
    }

    jcp.ch_block = isa == avx512_core ? 16 : 8;
    jcp.ch_tail = jcp.oc_without_padding % jcp.ch_block;

    // note: bf16 to be supported in the next commit
    bool ok_to_pad_channels
            = !is_data_layout_nxc && one_of(isa, avx512_core, avx2);
    if (ok_to_pad_channels) { jcp.ngroups = rnd_up(jcp.ngroups, jcp.ch_block); }

    bool args_ok = true
            && IMPLICATION(!is_data_layout_nxc, jcp.ngroups % jcp.ch_block == 0)
            && jcp.dilate_h == 0 && jcp.dilate_w == 0 && jcp.kw <= 3
            && jcp.stride_w <= jcp.kw // no gaps in kernel
            && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
            && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
    if (!args_ok) return status::unimplemented;

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);

    // Note: avx2 can't do masked_fma and would require extra Vmms
    // for byte_load.
    // TODO: enable 'is_fast_depthwise' for bf16 if it offers performance
    // improvement.
    jcp.is_fast_depthwise
            = !is_bf16 && is_data_layout_nxc && one_of(isa, avx512_core, avx2);
    constexpr int max_reg_idx = isa == avx512_core ? 31 : 15;
    // Note: anything larger than 4 didn't show significant speedup
    const int max_isa_unroll = jcp.is_fast_depthwise ? 4 : 1;
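    // Cap the channel unroll by the register budget: each unrolled channel
    // block appears to need roughly 2 * kw vector registers out of the
    // max_reg_idx usable registers.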
    int max_ch_unroll = nstl::min(max_isa_unroll, max_reg_idx / (2 * jcp.kw));
    jcp.nb_ch_blocking = nstl::min(jcp.nb_ch, max_ch_unroll);

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_hpad = (jcp.kh - 1 + 1) / 2;
    const int max_wpad = (jcp.kw - 1 + 1) / 2;
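    // i.e. kh / 2 and kw / 2: padding on either side may not exceed half
    // the filter size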
    const int min_ih = jcp.kh + nstl::modulo(-jcp.t_pad, jcp.stride_h);
    const bool boundaries_ok = true && jcp.t_pad <= max_hpad
            && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
            && jcp.r_pad <= max_wpad
            // input must fully accommodate the filter
            && jcp.ih >= min_ih
            // non-unit padding must be a multiple of the stride
            && IMPLICATION(jcp.t_pad > 1, jcp.t_pad % jcp.stride_h == 0)
            && IMPLICATION(jcp.b_pad > 1, jcp.b_pad % jcp.stride_h == 0);
    if (!boundaries_ok) return status::unimplemented;

    /* BF16: accumulation of output happens in f32, down-conversion to bf16
     * happens during the reduction phase. */
    jcp.typesize_out = sizeof(float);
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.bia_dt = jcp.with_bias ? cd.diff_bias_desc.data_type : data_type::undef;

    jcp.harness = is_data_layout_nxc ? harness_nxc : harness_mb_reduction;

    balance(jcp, nthreads);

    return status::success;
}

template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;

    if (jcp.harness == harness_mb_reduction) {
        /* Notes: if splitting thread work on 'mb', then a reduction has to take
         * place. Hence, book a per-thread, local weights-buffer for the
         * reduction */
        if (jcp.nthr_mb > 1) {
            const size_t mb = jcp.dwei_dt == data_type::bf16 ? jcp.nthr_mb
                                                             : jcp.nthr_mb - 1;
            const size_t wei_size
                    = static_cast<size_t>(jcp.ngroups) * jcp.kh * jcp.kw;
            scratchpad.book<float>(key_conv_wei_reduction, wei_size * mb);

            if (jcp.with_bias)
                scratchpad.book<float>(key_conv_bia_reduction,
                        static_cast<size_t>(jcp.ngroups) * (jcp.nthr_mb - 1));
        } else if (jcp.nthr_mb == 1 && jcp.dwei_dt == data_type::bf16) {
            const size_t wei_size
                    = static_cast<size_t>(jcp.ngroups) * jcp.kh * jcp.kw;
            scratchpad.book<float>(key_conv_wei_reduction, wei_size);
        }
    } else if (jcp.harness == harness_nxc) {
        if (jcp.nthr > 1 || jcp.dwei_dt == data_type::bf16) {
            assert(jcp.nthr > 0); // redundant check
            const size_t buff_count
                    = jcp.dwei_dt == data_type::bf16 ? jcp.nthr : jcp.nthr - 1;

            // note: because of weights blocked format, buffer is padded
            // across ch_block
            const size_t wei_size = static_cast<size_t>(utils::rnd_up(
                                            jcp.ngroups, jcp.ch_block))
                    * jcp.kh * jcp.kw;
            scratchpad.book<float>(
                    key_conv_wei_reduction, wei_size * buff_count);

            if (jcp.with_bias) {
                scratchpad.book<float>(
                        key_conv_bia_reduction, jcp.ngroups * buff_count);
            }
        }
    }

    if (jcp.bia_dt == data_type::bf16)
        scratchpad.book<float>(key_conv_bias_bf16_convert_wsp, jcp.ngroups);
}

template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::balance(
        jit_conv_conf_t &jcp, int nthreads) {
    jcp.nthr_oh = jcp.nthr_g = jcp.nthr_mb = 1;
    if (jcp.harness == harness_mb_reduction) {
        /* Basic heuristics for the parallel strategy:
         * 1) Try to parallelize over the number of groups (g), where tasks
         * are independent. Otherwise,
         * 2) Try to split the work across g and the minibatch (mb).
         * Parallelizing over mb requires a reduction of the weights.
         *
         * NOTE: because of the 'task partitioning' scheme, the per-thread
         * load becomes unbalanced when the number of threads is high
         * (e.g. > 16).
         */
        jcp.oh_blk_size = 15;
        jcp.nthr_g = nstl::min(jcp.nb_ch, nthreads);
        jcp.nthr_mb = nstl::min(nstl::max(1, nthreads / jcp.nthr_g), jcp.mb);
        jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
    } else if (jcp.harness == harness_nxc) {
        /* Allocate threads and partition space with regards to 'nb_ch', 'mb'
         * and 'nb_oh' (derived from selecting 'oh_block')
         *
         * note: 'prioritize_threading == true' showed slightly greater
         * performance, but there might be cases where the opposite holds true;
         * code is left for future tuning. */
        partition_nthr_nxc(jcp, nthreads, true);
        jcp.nthr = jcp.nthr_g * jcp.nthr_mb * jcp.nthr_oh;
    }
}

template <cpu_isa_t isa, data_type_t kernel_dt>
void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::partition_nthr_nxc(
        jit_conv_conf_t &jcp, int nthreads, bool prioritize_threading) {

    /* Explore the thread-partitioning space across 'nb_ch', 'mb' and 'nb_oh'
     * (determined by 'oh / oh_block'). Prioritize finding a partition that
     * uses the largest number of threads ('thr_eff').
     *
     * Additionally, try to reduce the work imbalance across threads
     * (i.e. 'total_imbalance').
     */
    float best_thr_eff = 0.; // maximize
    float best_imbalance = 1.; // minimize

    // Performance-tuning variables - enable through 'getenv_int()'
    // if necessary
    const int env_max_nthr_g = nthreads; // DNNL_MAX_NTHR_G
    const int env_max_nthr_mb = nthreads; // DNNL_MAX_NTHR_MB
    const int env_max_nthr_oh = nthreads; // DNNL_MAX_NTHR_OH
    const int env_min_oh_block = 1; // DNNL_MIN_OH_BLOCK

    const int ch_outer_blocks = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
    int max_g = nstl::min(env_max_nthr_g, nstl::min(ch_outer_blocks, nthreads));
    for (int g = max_g; g >= 1; --g) {
        int cur_nthr_g = g;
        auto div_nthr_g = nthreads / cur_nthr_g;

        int available_nthr_mb = div_nthr_g;
        int max_mb = nstl::min(
                env_max_nthr_mb, nstl::min(jcp.mb, available_nthr_mb));
        for (int mb = max_mb; mb >= 1; --mb) {
            int cur_nthr_mb = mb;
            auto div_nthr_mb = available_nthr_mb / cur_nthr_mb;

            // used to skip cases where efficiency can only worsen
            bool prev_under_blocked = false;

            int available_nthr_oh = nstl::min(
                    jcp.oh, nstl::min(env_max_nthr_oh, div_nthr_mb));
            int max_oh_block = jcp.oh;
            // Note: maybe it's worth exploring a heuristic to determine
            // optimal_min(oh_block)
            int min_oh_block
                    = nstl::max(1, nstl::min(jcp.oh, env_min_oh_block));
            for (int oh_block = max_oh_block; oh_block >= min_oh_block;
                    --oh_block) {

                // Calculate most efficient approximation for thread use and/or
                // blocking:
                int approx_g_block = utils::div_up(ch_outer_blocks, cur_nthr_g);
                int approx_mb_block = utils::div_up(jcp.mb, cur_nthr_mb);
                int approx_oh_block = utils::div_up(jcp.oh, oh_block);

                int cur_nthr_oh = nstl::min(available_nthr_oh, approx_oh_block);

                // calculate thread use efficiency
                int total_nthr = cur_nthr_g * cur_nthr_mb * cur_nthr_oh;
                float thr_eff = ((float)total_nthr) / nthreads;
                assert(total_nthr <= nthreads);

                // efficiency can only worsen, skip
                if (prev_under_blocked && available_nthr_oh < approx_oh_block) {
                    break;
                }

                // calculate imbalance
                float imbalance_g = ((float)std::abs(approx_g_block * cur_nthr_g
                                            - ch_outer_blocks))
                        / ch_outer_blocks;
                float imbalance_mb
                        = ((float)std::abs(
                                   approx_mb_block * cur_nthr_mb - jcp.mb))
                        / jcp.mb;
                float imbalance_oh
                        = ((float)std::abs(oh_block * cur_nthr_oh - jcp.oh))
                        / jcp.oh;
                float total_imbalance = imbalance_g * (jcp.mb * jcp.oh)
                        + imbalance_mb * (ch_outer_blocks * jcp.oh)
                        + imbalance_oh * (ch_outer_blocks * jcp.mb);
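                // Each relative imbalance term is weighted by the extent of
                // the other two partitioned dimensions, so the three
                // contributions are on a comparable scale.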

                /* 1) When 'prioritize_threading == true'
                 * First condition: pick the blocking strategy that uses the
                 * most threads.
                 * Second condition: if the current blocking strategy ties the
                 * previous best on thread use, choose it only if its work
                 * imbalance is lower.
                 *
                 * 2) Otherwise ('prioritize_threading == false')
                 * First condition: pick the blocking strategy with the lowest
                 * thread work imbalance.
                 * Second condition: if the current blocking strategy ties the
                 * previous best on work imbalance, choose it only if it keeps
                 * more threads working.
                 * */
                const bool first_condition = prioritize_threading
                        ? best_thr_eff <= thr_eff
                        : best_imbalance >= total_imbalance;
                const bool second_condition = prioritize_threading
                        ? best_thr_eff == thr_eff
                                && best_imbalance <= total_imbalance
                        : best_imbalance == total_imbalance
                                && best_thr_eff >= thr_eff;
                if (first_condition) {
                    if (second_condition) { continue; }
                    jcp.nthr_g = cur_nthr_g;
                    jcp.nthr_mb = cur_nthr_mb;
                    jcp.nthr_oh = cur_nthr_oh;
                    jcp.oh_blk_size = oh_block;
                    best_imbalance = total_imbalance;
                    best_thr_eff = thr_eff;
                }
                prev_under_blocked = oh_block * cur_nthr_oh < jcp.oh;
            }
        }
    }
}

REG_AVX512_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx512_core, bf16>);
REG_AVX512_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_fwd_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_fwd_kernel<sse41, f32>);

REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, bf16>);
REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_bwd_data_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_bwd_data_kernel<sse41, f32>);

REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core, bf16>);
REG_AVX512_ISA(
        template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_conv_bwd_weights_kernel<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_conv_bwd_weights_kernel<sse41, f32>);

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
813