1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_hashing.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/cpu_reducer.hpp" |
32 | #include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp" |
33 | #include "cpu/x64/jit_transpose_utils.hpp" |
34 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
35 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
// JIT avx512_core 1x1 (pointwise) convolution, forward propagation.
//
// Notable features handled by this implementation:
//  - optional fusion with a depthwise convolution post-op (jcp_.with_dw_conv),
//  - "reduce to unit stride" (rtus) preprocessing for strided 1x1 problems.
// Template parameters select the src/weights/dst data types; weights and dst
// default to the src data type.
template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
        impl::data_type_t dst_type = src_type>
struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t {
    // Primitive descriptor: validates the operation descriptor/attributes,
    // fills the jit configuration (jcp_) and rtus info, and - when a dw-conv
    // post-op is present - creates the fused depthwise convolution's own pd.
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        // Copy construction must deep-copy dw_conv_pd_; on failure the pd is
        // marked uninitialized instead of throwing.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , avx512_core, "" ),
                jit_avx512_common_1x1_convolution_fwd_t);

        // Checks applicability, initializes jcp_/rtus_ and books all
        // scratchpad memory. Returns status::unimplemented when this
        // implementation cannot handle the problem.
        status_t init(engine_t *engine) {
            using namespace utils;
            bool ok = true && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(src_type, wei_type, dst_type, dst_type,
                            data_type::undef)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, *conv_d,
                    *src_d, *weights_md(), *dst_md(), *attr(),
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv the externally visible dst of this primitive
        // is the dw conv's dst.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes queries for the fused dw conv's weights/bias to its pd;
        // everything else falls back to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // The fused dw conv's weights (and bias, when present) are extra
        // inputs of this primitive.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // jit kernel configuration
        reduce_to_unit_stride_t rtus_; // strided-1x1 (rtus) bookkeeping
        using dw_pd_t = jit_avx512_common_dw_convolution_fwd_t::pd_t;
        // pd of the fused depthwise convolution; null unless with_dw_conv.
        std::unique_ptr<dw_pd_t> dw_conv_pd_;

    protected:
        // Picks default memory formats: plain nxc when the user requested it
        // (or left "any" alongside an nxc counterpart), otherwise 16c-blocked
        // data layouts with 16i16o-blocked weights.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Helper for the copy ctor: copies pod members and clones the fused
        // dw conv pd when present.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(other.dw_conv_pd_->clone());
                if (!dw_conv_pd_) return status::out_of_memory;
            }
            return status::success;
        }

        // Sets up fusion with a depthwise convolution post-op: builds the dw
        // conv descriptor from this conv's dst, creates/initializes its pd,
        // reconciles the blocking parameters of the two kernels, and books
        // the intermediate (fusion) buffer in the scratchpad.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            // The 1x1 conv's dst serves as the dw conv's src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            // TODO: Add a check if better ISA exists following above note.
            bool ok = true
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);
            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            CHECK(safe_ptr_assign(
                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
            CHECK(dw_conv_pd_->init(engine));
            auto &jcp_dw = dw_conv_pd_->jcp_;

            // The fused pair is only supported when the layouts agree, the
            // 1x1 oc is fully blocked, and the dw kernel processes full rows.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw.is_fused_conv = true;
            // TODO: Support/experiment arbitrary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
                --jcp_dw.nb_ch_blocking;

            jcp_dw.dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
            const bool is_data_nxc = utils::everyone_is(
                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
            if (!is_data_nxc)
                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
                        * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread row buffer that carries the 1x1 conv's output rows
            // into the fused dw conv.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
                    * jcp_dw.dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            jit_uni_dw_conv_fwd_kernel<avx512_core,
                    data_type::f32>::init_scratchpad(dw_scratchpad, jcp_dw);

            return status::success;
        }
    };

    // init_rtus_driver needs access to the private kernel/driver members.
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<src_type>::type src_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;

    // Creates the 1x1 conv kernel, the fused dw conv kernel (when present),
    // and the rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_common_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(
                            pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Parallel driver (defined in the .cpp); execute_forward_thr is the
    // per-thread body.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr,
            const src_data_t *src, const wei_data_t *weights,
            const dst_data_t *bias, const wei_data_t *weights_dw,
            const dst_data_t *bias_dw, dst_data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<avx512_core>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
303 | |
// Convenience alias for the common all-f32 forward configuration.
using jit_avx512_common_1x1_convolution_fwd_f32_t
        = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
306 | |
307 | template <impl::data_type_t diff_dst_type, |
308 | impl::data_type_t wei_type = diff_dst_type, |
309 | impl::data_type_t diff_src_type = diff_dst_type> |
310 | struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t { |
311 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
312 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
313 | const convolution_fwd_pd_t *hint_fwd_pd) |
314 | : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) |
315 | , jcp_() |
316 | , rtus_() {} |
317 | |
318 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , avx512_core, "" ), |
319 | jit_avx512_common_1x1_convolution_bwd_data_t); |
320 | |
321 | status_t init(engine_t *engine) { |
322 | bool ok = true && desc()->prop_kind == prop_kind::backward_data |
323 | && set_default_alg_kind(alg_kind::convolution_direct) |
324 | && expect_data_types(diff_src_type, wei_type, |
325 | data_type::undef, diff_dst_type, data_type::undef) |
326 | && attr()->has_default_values() && !has_zero_dim_memory() |
327 | && set_default_formats(); |
328 | if (!ok) return status::unimplemented; |
329 | |
330 | const convolution_desc_t *conv_d = desc(); |
331 | const memory_desc_t *diff_src_d = diff_src_md(); |
332 | rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md()); |
333 | |
334 | status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, |
335 | *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), |
336 | *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); |
337 | if (status != status::success) return status; |
338 | |
339 | auto scratchpad = scratchpad_registry().registrar(); |
340 | jit_avx512_common_1x1_conv_kernel::init_scratchpad( |
341 | scratchpad, jcp_); |
342 | |
343 | rtus_prepare_space_info(this, scratchpad, jcp_.nthr); |
344 | |
345 | return status::success; |
346 | } |
347 | |
348 | // TODO (Roma): structs conf header cleanup |
349 | jit_1x1_conv_conf_t jcp_; |
350 | reduce_to_unit_stride_t rtus_; |
351 | |
352 | protected: |
353 | bool set_default_formats() { |
354 | using namespace format_tag; |
355 | |
356 | const memory_desc_wrapper diff_src_d(&diff_src_md_); |
357 | const memory_desc_wrapper diff_dst_d(&diff_dst_md_); |
358 | |
359 | const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); |
360 | const auto dat_tag_nCx16c |
361 | = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); |
362 | const auto curr_src_tag = diff_src_d.matches_one_of_tag( |
363 | dat_tag_nxc, dat_tag_nCx16c); |
364 | const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( |
365 | dat_tag_nxc, dat_tag_nCx16c); |
366 | const auto is_data_layout_nxc |
367 | = IMPLICATION(curr_src_tag != dat_tag_nxc, |
368 | diff_src_d.format_kind() == format_kind::any) |
369 | && IMPLICATION(curr_dst_tag != dat_tag_nxc, |
370 | diff_dst_d.format_kind() == format_kind::any) |
371 | && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
372 | auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; |
373 | auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), |
374 | IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, IOdhw16o16i, |
375 | gIOdhw16o16i); |
376 | |
377 | return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
378 | } |
379 | }; |
380 | |
381 | template <cpu_isa_t isa, typename conv_t> |
382 | friend status_t init_rtus_driver(conv_t *self); |
383 | |
384 | jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) |
385 | : primitive_t(apd) {} |
386 | |
387 | typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; |
388 | typedef typename prec_traits<wei_type>::type wei_data_t; |
389 | typedef typename prec_traits<diff_src_type>::type diff_src_data_t; |
390 | |
391 | status_t init(engine_t *engine) override { |
392 | CHECK(safe_ptr_assign(kernel_, |
393 | new jit_avx512_common_1x1_conv_kernel( |
394 | pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); |
395 | CHECK(kernel_->create_kernel()); |
396 | CHECK(init_rtus_driver<avx512_core>(this)); |
397 | return status::success; |
398 | } |
399 | |
400 | status_t execute(const exec_ctx_t &ctx) const override { |
401 | execute_backward_data(ctx); |
402 | return status::success; |
403 | } |
404 | |
405 | private: |
406 | void execute_backward_data(const exec_ctx_t &ctx) const; |
407 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
408 | |
409 | std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_; |
410 | std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_; |
411 | }; |
412 | |
// Convenience alias for the common all-f32 backward-data configuration.
using jit_avx512_common_1x1_convolution_bwd_data_f32_t
        = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
415 | |
416 | struct jit_avx512_common_1x1_convolution_bwd_weights_t : public primitive_t { |
417 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
418 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
419 | const convolution_fwd_pd_t *hint_fwd_pd) |
420 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
421 | , jcp_() |
422 | , rtus_() {} |
423 | |
424 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , avx512_core, "" ), |
425 | jit_avx512_common_1x1_convolution_bwd_weights_t); |
426 | |
427 | status_t init(engine_t *engine) { |
428 | bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
429 | && set_default_alg_kind(alg_kind::convolution_direct) |
430 | && expect_data_types(data_type::f32, data_type::f32, |
431 | data_type::f32, data_type::f32, data_type::f32) |
432 | && attr()->has_default_values() && !has_zero_dim_memory() |
433 | && set_default_formats(); |
434 | if (!ok) return status::unimplemented; |
435 | |
436 | const convolution_desc_t *conv_d = desc(); |
437 | const memory_desc_t *src_d = src_md(); |
438 | rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md()); |
439 | |
440 | status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, |
441 | *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), |
442 | *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); |
443 | if (status != status::success) return status; |
444 | |
445 | init_balancers(); |
446 | |
447 | auto scratchpad = scratchpad_registry().registrar(); |
448 | jit_avx512_common_1x1_conv_kernel::init_scratchpad( |
449 | scratchpad, jcp_); |
450 | |
451 | auto reducer_bia_scratchpad = memory_tracking::registrar_t( |
452 | scratchpad, memory_tracking::names::prefix_reducer_bia); |
453 | reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); |
454 | |
455 | rtus_prepare_space_info(this, scratchpad, jcp_.nthr); |
456 | |
457 | return status::success; |
458 | } |
459 | |
460 | // TODO (Roma): structs conf header cleanup |
461 | jit_1x1_conv_conf_t jcp_; |
462 | cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; |
463 | reduce_to_unit_stride_t rtus_; |
464 | |
465 | protected: |
466 | bool set_default_formats() { |
467 | using namespace format_tag; |
468 | |
469 | const memory_desc_wrapper src_d(&src_md_); |
470 | const memory_desc_wrapper diff_dst_d(&diff_dst_md_); |
471 | |
472 | const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); |
473 | const auto dat_tag_nCx16c |
474 | = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); |
475 | const auto curr_src_tag |
476 | = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); |
477 | const auto curr_dst_tag = diff_dst_d.matches_one_of_tag( |
478 | dat_tag_nxc, dat_tag_nCx16c); |
479 | const auto is_data_layout_nxc |
480 | = IMPLICATION(curr_src_tag != dat_tag_nxc, |
481 | src_d.format_kind() == format_kind::any) |
482 | && IMPLICATION(curr_dst_tag != dat_tag_nxc, |
483 | diff_dst_d.format_kind() == format_kind::any) |
484 | && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
485 | |
486 | auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; |
487 | auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), |
488 | OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o, |
489 | gOIdhw16i16o); |
490 | |
491 | return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
492 | } |
493 | |
494 | private: |
495 | void init_balancers() { |
496 | const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; |
497 | if (with_bias()) { |
498 | reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, |
499 | jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb, |
500 | max_buffer_size, true)); |
501 | } |
502 | } |
503 | }; |
504 | |
505 | template <cpu_isa_t isa, typename conv_t> |
506 | friend status_t init_rtus_driver(conv_t *self); |
507 | |
508 | jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) |
509 | : primitive_t(apd) {} |
510 | |
511 | typedef typename prec_traits<data_type::f32>::type data_t; |
512 | |
513 | status_t init(engine_t *engine) override; |
514 | |
515 | status_t execute(const exec_ctx_t &ctx) const override { |
516 | execute_backward_weights(ctx); |
517 | return status::success; |
518 | } |
519 | |
520 | private: |
521 | void execute_backward_weights(const exec_ctx_t &ctx) const; |
522 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
523 | |
524 | std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_; |
525 | std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; |
526 | std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_; |
527 | std::unique_ptr<jit_transpose4x16_src> trans_kernel_; |
528 | std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_; |
529 | }; |
530 | |
531 | } // namespace x64 |
532 | } // namespace cpu |
533 | } // namespace impl |
534 | } // namespace dnnl |
535 | |
536 | #endif |
537 | |