1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP
18#define CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/primitive_hashing.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/dw_convolution_utils.hpp"
29#include "cpu/platform.hpp"
30
31#include "cpu/x64/cpu_reducer.hpp"
32#include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp"
33#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
34#include "cpu/x64/jit_uni_dw_convolution.hpp"
35
36namespace dnnl {
37namespace impl {
38namespace cpu {
39namespace x64 {
40
41struct jit_avx2_1x1_convolution_fwd_t : public primitive_t {
42 // TODO: (Roma) Code duplication duplication! Remove with templates
43 // (maybe...)!
44 struct pd_t : public cpu_convolution_fwd_pd_t {
45 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
46 const typename pd_t::base_class *hint_fwd_pd)
47 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
48 , jcp_()
49 , rtus_()
50 , jcp_dw_(nullptr) {}
51
52 pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
53 if (copy(other) != status::success) is_initialized_ = false;
54 }
55
56 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", jcp_.isa, ""),
57 jit_avx2_1x1_convolution_fwd_t);
58
59 status_t init(engine_t *engine) {
60 bool ok = true && is_fwd()
61 && set_default_alg_kind(alg_kind::convolution_direct)
62 && expect_data_types(data_type::f32, data_type::f32,
63 data_type::f32, data_type::f32, data_type::f32)
64 && attr()->has_default_values(
65 primitive_attr_t::skip_mask_t::post_ops,
66 data_type::f32)
67 && !has_zero_dim_memory() && set_default_formats()
68 && attr_.set_default_formats(dst_md(0)) == status::success;
69 if (!ok) return status::unimplemented;
70
71 const convolution_desc_t *conv_d = desc();
72 const memory_desc_t *src_d = src_md();
73 rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());
74
75 CHECK(jit_avx2_1x1_conv_kernel_f32::init_conf(
76 jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr()));
77 if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));
78
79 auto scratchpad = scratchpad_registry().registrar();
80 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
81
82 rtus_prepare_space_info(this, scratchpad, jcp_.nthr);
83
84 return status::success;
85 }
86
87 const memory_desc_t *dst_md(int index = 0) const override {
88 return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
89 }
90
91 const memory_desc_t *arg_md(int index = 0) const override {
92 if (jcp_.with_dw_conv) {
93 switch (index) {
94 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
95 return dw_conv_pd_->weights_md(0);
96 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
97 return dw_conv_pd_->weights_md(1);
98 default: break;
99 }
100 }
101 return convolution_fwd_pd_t::arg_md(index);
102 }
103
104 arg_usage_t arg_usage(int arg) const override {
105 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
106 return arg_usage_t::input;
107
108 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
109 && attr_post_op_dw_inputs() > 1)
110 return arg_usage_t::input;
111
112 return convolution_fwd_pd_t::arg_usage(arg);
113 }
114
115 jit_1x1_conv_conf_t jcp_;
116 reduce_to_unit_stride_t rtus_;
117 jit_conv_conf_t *jcp_dw_;
118 std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;
119
120 protected:
121 template <cpu_isa_t isa>
122 using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<isa,
123 data_type::f32>::pd_t;
124
125 bool set_default_formats() {
126 using namespace format_tag;
127
128 const memory_desc_wrapper src_d(&src_md_);
129 const memory_desc_wrapper dst_d(&dst_md_);
130
131 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
132 const auto dat_tag_nCx8c
133 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
134 const auto curr_src_tag
135 = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
136 const auto curr_dst_tag
137 = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
138 const auto is_data_layout_nxc
139 = IMPLICATION(curr_src_tag != dat_tag_nxc,
140 src_d.format_kind() == format_kind::any)
141 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
142 dst_d.format_kind() == format_kind::any)
143 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
144 auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
145 auto wei_tag = with_groups()
146 ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
147 : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
148
149 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
150 }
151
152 status_t copy(const pd_t &other) {
153 jcp_ = other.jcp_;
154 rtus_ = other.rtus_;
155 jcp_dw_ = nullptr;
156 if (other.dw_conv_pd_) {
157 dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
158 other.dw_conv_pd_->clone()));
159 if (!dw_conv_pd_) return status::out_of_memory;
160 if (jcp_.isa == avx2) {
161 jcp_dw_ = &(static_cast<dw_pd_t<avx2> *>(dw_conv_pd_.get())
162 ->jcp_);
163 } else { // sse41
164 jcp_dw_ = &(static_cast<dw_pd_t<sse41> *>(dw_conv_pd_.get())
165 ->jcp_);
166 }
167 }
168
169 return status::success;
170 }
171
172 status_t depthwise_po_init(engine_t *engine) {
173
174 using namespace memory_tracking;
175 auto &jcp_1x1 = jcp_;
176 primitive_attr_t attr_1x1(*attr());
177 if (!attr_1x1.is_initialized()) return status::out_of_memory;
178 jit_conv_conf_t *jcp_dw = nullptr;
179
180 const auto &src_md = dst_md_;
181 const memory_desc_wrapper src_d(src_md);
182 const auto nthr = dnnl_get_max_threads();
183 auto l2_cache = platform::get_per_core_cache_size(2) * nthr;
184
185 // Note: A robust fusion implementation would be to check if both
186 // 1x1 conv and dw conv that are considered here for fusion are
187 // optimal independently. This would require creating a new
188 // primitive_desc through primitive_iterator & check if they match.
189 // Due to concern that these creations and/or checks could be heavy,
190 // for 1x1: Check that no better ISA is available.
191 // for dw: Always fuse with same ISA.
192 // Caveat: May be a better dw conv exists.
193
194 bool ok = true && (!mayiuse(avx512_core))
195 && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
196 // TODO: Below may be further tuned.
197 && (l2_cache * 2 < src_d.size())
198 // load_grp_count check can be redundant due to l2 check
199 // above. Adding it explicitly as the current driver doesn't
200 // work if this condition fails.
201 && (jcp_1x1.load_grp_count < 2);
202 if (!ok) return status::unimplemented;
203
204 int dw_po_index
205 = attr_1x1.post_ops_.find(primitive_kind::convolution);
206
207 convolution_desc_t cd_dw;
208 primitive_attr_t attr_dw;
209
210 CHECK(get_depthwise_conv_desc(
211 cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));
212
213 if (jcp_1x1.isa == avx2) {
214 std::unique_ptr<dw_pd_t<avx2>> fusable_pd(
215 new dw_pd_t<avx2>(&cd_dw, &attr_dw, nullptr));
216 CHECK(fusable_pd->init(engine));
217 jcp_dw = &(fusable_pd->jcp_);
218 dw_conv_pd_ = std::move(fusable_pd);
219 } else {
220 // Special case for this primitive, as we dont have dw<avx>.
221 // In this case fuse with sse41 depthwise conv
222 // NOTE: Currently dw f32 kernel is similar for all ISA and can
223 // be fused regardless of ISA if inter-connecting md_ matches.
224 std::unique_ptr<dw_pd_t<sse41>> fusable_pd(
225 new dw_pd_t<sse41>(&cd_dw, &attr_dw, nullptr));
226 CHECK(fusable_pd->init(engine));
227 jcp_dw = &(fusable_pd->jcp_);
228 dw_conv_pd_ = std::move(fusable_pd);
229 }
230
231 ok = true
232 && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
233 && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
234 && IMPLICATION(
235 jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow);
236 if (!ok) return status::unimplemented;
237
238 assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
239 assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
240 assert(IMPLICATION(
241 dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
242 dw_conv_pd_->weights_md(1)->format_kind
243 != format_kind::any));
244
245 jcp_dw->is_fused_conv = true;
246 // TODO: Support/experiment arbitary oc_work in dw conv.
247 // Until then we keep oc_work perfectly divisible.
248 while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
249 --jcp_1x1.nb_load_blocking;
250 jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;
251
252 while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0)
253 --jcp_dw->nb_ch_blocking;
254
255 jcp_dw->dw_conv_buffer_oc
256 = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
257
258 const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
259 format_tag::nhwc, format_tag::ndhwc);
260 const bool is_data_nxc = utils::everyone_is(
261 dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
262 if (!is_data_nxc)
263 jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
264 * jcp_1x1.typesize_out;
265
266 registrar_t scratchpad(scratchpad_registry_);
267 registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);
268
269 size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw
270 * jcp_dw->dw_conv_buffer_oc;
271 assert(dw_conv_buffer_size_);
272 dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
273 dw_conv_buffer_size_,
274 types::data_type_size(dw_conv_pd_->src_md()->data_type));
275
276 if (jcp_1x1.isa == avx2)
277 dw_conv_kernel_t<avx2>::init_scratchpad(dw_scratchpad, *jcp_dw);
278 else
279 dw_conv_kernel_t<sse41>::init_scratchpad(
280 dw_scratchpad, *jcp_dw);
281
282 return status::success;
283 }
284 };
285
286 template <cpu_isa_t isa, typename conv_t>
287 friend status_t init_rtus_driver(conv_t *self);
288
289 jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
290
291 status_t init(engine_t *engine) override {
292 CHECK(safe_ptr_assign(kernel_,
293 new jit_avx2_1x1_conv_kernel_f32(
294 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
295 CHECK(kernel_->create_kernel());
296 CHECK(init_rtus_driver<avx2>(this));
297 if (pd()->jcp_.with_dw_conv) {
298 auto &isa = pd()->jcp_.isa;
299
300 if (isa == avx2) {
301 CHECK(safe_ptr_assign(kernel_dw_avx2,
302 new dw_conv_kernel_t<avx2>(
303 *(pd()->jcp_dw_), *pd()->dst_md(0))));
304 CHECK(kernel_dw_avx2->create_kernel());
305 } else {
306 CHECK(safe_ptr_assign(kernel_dw_sse41,
307 new dw_conv_kernel_t<sse41>(
308 *(pd()->jcp_dw_), *pd()->dst_md(0))));
309 CHECK(kernel_dw_sse41->create_kernel());
310 }
311 }
312
313 return status::success;
314 }
315
316 typedef typename prec_traits<data_type::f32>::type data_t;
317
318 status_t execute(const exec_ctx_t &ctx) const override {
319 execute_forward(ctx);
320 return status::success;
321 }
322
323private:
324 void execute_forward(const exec_ctx_t &ctx) const;
325 void execute_forward_thr(const int ithr, const int nthr, const data_t *src,
326 const data_t *weights, const data_t *bias, const data_t *weights_dw,
327 const data_t *bias_dw, data_t *dst,
328 const memory_tracking::grantor_t &scratchpad,
329 const void *post_ops_binary_rhs_arg_vec,
330 const void *post_ops_binary_rhs_arg_vec_dw) const;
331 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
332
333 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
334 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;
335
336 template <cpu_isa_t isa>
337 using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel<isa, data_type::f32>;
338
339 std::unique_ptr<dw_conv_kernel_t<avx2>> kernel_dw_avx2;
340 std::unique_ptr<dw_conv_kernel_t<sse41>> kernel_dw_sse41;
341};
342
343struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t {
344 struct pd_t : public cpu_convolution_bwd_data_pd_t {
345 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
346 const convolution_fwd_pd_t *hint_fwd_pd)
347 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
348 , jcp_()
349 , rtus_() {}
350
351 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
352 jit_avx2_1x1_convolution_bwd_data_t);
353
354 status_t init(engine_t *engine) {
355 bool ok = true && desc()->prop_kind == prop_kind::backward_data
356 && set_default_alg_kind(alg_kind::convolution_direct)
357 && expect_data_types(data_type::f32, data_type::f32,
358 data_type::undef, data_type::f32, data_type::f32)
359 && attr()->has_default_values() && !has_zero_dim_memory()
360 && set_default_formats();
361 if (!ok) return status::unimplemented;
362
363 const convolution_desc_t *conv_d = desc();
364 const memory_desc_t *diff_src_d = diff_src_md();
365 rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());
366
367 status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
368 *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
369 *attr());
370 if (status != status::success) return status;
371
372 auto scratchpad = scratchpad_registry().registrar();
373 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
374
375 rtus_prepare_space_info(this, scratchpad, jcp_.nthr);
376
377 return status::success;
378 }
379
380 jit_1x1_conv_conf_t jcp_;
381 reduce_to_unit_stride_t rtus_;
382
383 protected:
384 bool set_default_formats() {
385 using namespace format_tag;
386
387 const memory_desc_wrapper diff_src_d(&diff_src_md_);
388 const memory_desc_wrapper diff_dst_d(&diff_dst_md_);
389
390 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
391 const auto dat_tag_nCx8c
392 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
393 const auto curr_src_tag
394 = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
395 const auto curr_dst_tag
396 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
397 const auto is_data_layout_nxc
398 = IMPLICATION(curr_src_tag != dat_tag_nxc,
399 diff_src_d.format_kind() == format_kind::any)
400 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
401 diff_dst_d.format_kind() == format_kind::any)
402 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
403 auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
404 auto wei_tag = with_groups()
405 ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
406 : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
407
408 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
409 }
410 };
411
412 template <cpu_isa_t isa, typename conv_t>
413 friend status_t init_rtus_driver(conv_t *self);
414
415 jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
416
417 typedef typename prec_traits<data_type::f32>::type data_t;
418
419 status_t init(engine_t *engine) override {
420 CHECK(safe_ptr_assign(kernel_,
421 new jit_avx2_1x1_conv_kernel_f32(
422 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
423 CHECK(kernel_->create_kernel());
424 CHECK(init_rtus_driver<avx2>(this));
425 return status::success;
426 }
427
428 status_t execute(const exec_ctx_t &ctx) const override {
429 execute_backward_data(ctx);
430 return status::success;
431 }
432
433private:
434 void execute_backward_data(const exec_ctx_t &ctx) const;
435 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
436
437 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
438 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;
439};
440
441struct jit_avx2_1x1_convolution_bwd_weights_t : public primitive_t {
442 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
443 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
444 const convolution_fwd_pd_t *hint_fwd_pd)
445 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
446 , jcp_()
447 , rtus_() {}
448
449 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
450 jit_avx2_1x1_convolution_bwd_weights_t);
451
452 status_t init(engine_t *engine) {
453 bool ok = true && desc()->prop_kind == prop_kind::backward_weights
454 && set_default_alg_kind(alg_kind::convolution_direct)
455 && expect_data_types(data_type::f32, data_type::f32,
456 data_type::f32, data_type::f32, data_type::f32)
457 && attr()->has_default_values() && !has_zero_dim_memory()
458 && set_default_formats();
459 if (!ok) return status::unimplemented;
460
461 const convolution_desc_t *conv_d = desc();
462 const memory_desc_t *src_d = src_md();
463 rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md());
464
465 status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
466 *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
467 *attr());
468 if (status != status::success) return status;
469
470 init_balancers();
471
472 auto scratchpad = scratchpad_registry().registrar();
473 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
474
475 rtus_prepare_space_info(this, scratchpad, jcp_.nthr);
476
477 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
478 scratchpad, memory_tracking::names::prefix_reducer_bia);
479 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
480
481 auto reducer_wei_scratchpad = memory_tracking::registrar_t(
482 scratchpad, memory_tracking::names::prefix_reducer_wei);
483 reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
484
485 return status::success;
486 }
487
488 jit_1x1_conv_conf_t jcp_;
489 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
490 cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_;
491 reduce_to_unit_stride_t rtus_;
492
493 protected:
494 bool set_default_formats() {
495 using namespace format_tag;
496
497 const memory_desc_wrapper src_d(&src_md_);
498 const memory_desc_wrapper diff_dst_d(&diff_dst_md_);
499
500 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
501 const auto dat_tag_nCx8c
502 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
503 const auto curr_src_tag
504 = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
505 const auto curr_dst_tag
506 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
507 const auto is_data_layout_nxc
508 = IMPLICATION(curr_src_tag != dat_tag_nxc,
509 src_d.format_kind() == format_kind::any)
510 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
511 diff_dst_d.format_kind() == format_kind::any)
512 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
513
514 auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
515 auto wei_tag = with_groups()
516 ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
517 : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);
518
519 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
520 }
521
522 private:
523 void init_balancers() {
524 const int ic_block = jcp_.bcast_block;
525 const int nb_ic = jcp_.nb_bcast;
526 const int nb_ic_blocking = jcp_.nb_bcast_blocking;
527 const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);
528
529 const int oc_block = jcp_.load_block;
530 const int nb_oc = jcp_.nb_load;
531 const int nb_oc_blocking = jcp_.nb_load_blocking;
532 const int load_work = utils::div_up(nb_oc, nb_oc_blocking);
533
534 const int job_size
535 = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
536 const int njobs_x = bcast_work;
537 const int njobs_y = jcp_.ngroups * load_work;
538
539 const int max_threads = dnnl_get_max_threads();
540 const size_t max_buffer_size = (size_t)max_threads * job_size * 8;
541
542 if (with_bias()) {
543 reducer_bia_conf_.init(reduce_balancer_t(max_threads, oc_block,
544 jcp_.ngroups * nb_oc, jcp_.mb, max_buffer_size, true));
545 }
546
547 reducer_wei_conf_.init(
548 reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
549 jcp_.mb * jcp_.nb_reduce, max_buffer_size, true),
550 job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
551 nb_ic * ic_block * oc_block, nb_oc);
552 }
553 };
554
555 template <cpu_isa_t isa, typename conv_t>
556 friend status_t init_rtus_driver(conv_t *self);
557
558 jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd)
559 : primitive_t(apd) {}
560
561 typedef typename prec_traits<data_type::f32>::type data_t;
562
563 status_t init(engine_t *engine) override;
564
565 status_t execute(const exec_ctx_t &ctx) const override {
566 execute_backward_weights(ctx);
567 return status::success;
568 }
569
570private:
571 void execute_backward_weights(const exec_ctx_t &ctx) const;
572 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
573
574 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
575 std::unique_ptr<cpu_reducer_2d_t<data_type::f32>> reducer_weights_;
576 std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_;
577 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;
578};
579
580} // namespace x64
581} // namespace cpu
582} // namespace impl
583} // namespace dnnl
584
585#endif
586