1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONVOLUTION_HPP
18#define CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/primitive.hpp"
23#include "common/primitive_hashing.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/cpu_engine.hpp"
28#include "cpu/dw_convolution_utils.hpp"
29#include "cpu/platform.hpp"
30
31#include "cpu/x64/cpu_reducer.hpp"
32#include "cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp"
33#include "cpu/x64/jit_transpose_utils.hpp"
34#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
35#include "cpu/x64/jit_uni_dw_convolution.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace cpu {
40namespace x64 {
41
// Forward 1x1 convolution with bf16 src/weights, JIT-ed for avx512_core.
// dst_type selects the destination data type (bf16 or f32; enforced by
// expect_data_types() in pd_t::init()). Supports an optional fused depthwise
// convolution post-op (jcp_.with_dw_conv), in which case this primitive's
// visible dst is the depthwise conv's dst (see dst_md() override).
template <impl::data_type_t dst_type>
struct jit_avx512_core_bf16_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copy ctor must deep-copy the fused dw conv pd and re-point
        // jcp_dw_ into the clone (see copy()); on failure the pd is
        // flagged as not initialized instead of throwing.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:", jcp_.isa, ""),
                jit_avx512_core_bf16_1x1_convolution_fwd_t);

        status_t init(engine_t *engine) {
            // NOTE: the conjuncts below are order-dependent;
            // set_default_alg_kind() / set_default_formats() mutate the pd
            // and rely on short-circuit evaluation of earlier checks.
            bool ok = true && mayiuse(avx512_core) && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::bf16, data_type::bf16,
                            data_type::undef, dst_type, data_type::undef)
                    && IMPLICATION(with_bias(),
                            utils::one_of(weights_md(1)->data_type,
                                    data_type::f32, data_type::bf16))
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            // Reduce-to-unit-stride: for strided 1x1 convs an auxiliary
            // copy of src may be made so the kernel sees unit stride.
            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_core_bf16_1x1_conv_kernel::init_conf(jcp_, *conv_d,
                    *src_d, *weights_md(), *dst_md(), attr_,
                    dnnl_get_max_threads(), rtus_.reduce_src_));

            // Set up (and validate) the fused depthwise conv, if requested
            // via the post-ops attribute.
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            CHECK(jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_));

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv, the externally visible dst is the dw conv's
        // dst; otherwise this pd's own dst_md_.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Route queries for the dw post-op's weights/bias descriptors to the
        // nested dw conv pd; everything else falls back to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // Declare the fused dw conv's weights (and bias, when present) as
        // additional inputs of this primitive.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;
        // Points into dw_conv_pd_'s jcp_ (set in copy(); depthwise_po_init()
        // only fills dw_conv_pd_). Primitive creation goes through a cloned
        // pd, so the clone's jcp_dw_ is valid by the time init() below runs.
        jit_conv_conf_t *jcp_dw_; // doesn't own a resource
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;

    protected:
        // pd type of the fused depthwise conv, parameterized by its dst
        // data type (bf16 or f32); src/wei of the dw conv are bf16.
        template <data_type_t ddt>
        using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<avx512_core,
                data_type::bf16, ddt>::pd_t;

        // Chooses between plain (nxc) and 16c-blocked data layouts, and the
        // matching 8i16o2i bf16 weights layout. nxc is selected only when
        // neither src nor dst format contradicts it (format_kind::any counts
        // as compatible) and at least one of them already is nxc.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);

            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw8i16o2i, gOIw8i16o2i, OIhw8i16o2i, gOIhw8i16o2i,
                    OIdhw8i16o2i, gOIdhw8i16o2i);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Deep-copy helper used by the copy ctor: clones the fused dw conv
        // pd (if any) and rebinds jcp_dw_ to the clone's jcp_, dispatching
        // on the dw conv's dst data type.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            using namespace data_type;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;
                auto dw_dst_dt = dw_conv_pd_->dst_md()->data_type;

                switch (dw_dst_dt) {
                    case bf16:
                        jcp_dw_ = &(
                                static_cast<dw_pd_t<bf16> *>(dw_conv_pd_.get())
                                        ->jcp_);
                        break;
                    case f32:
                        jcp_dw_ = &(
                                static_cast<dw_pd_t<f32> *>(dw_conv_pd_.get())
                                        ->jcp_);
                        break;
                    default: assert(!"unreachable");
                }
            }
            return status::success;
        }

        // Builds and validates the fused depthwise conv pd from the dw
        // post-op (the 1x1's dst becomes the dw conv's src). On success,
        // dw_conv_pd_ owns the new pd, the 1x1/dw blockings are made
        // mutually divisible, and the intermediate (fusion) buffer is
        // booked in the scratchpad.
        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            jit_conv_conf_t *jcp_dw = nullptr;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            // The 1x1's dst md doubles as the dw conv's src md.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = !mayiuse(avx512_core_amx)
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            auto dw_dst_dt = cd_dw.dst_desc.data_type;

// Instantiate and init the dw conv pd for a concrete dst data type;
// on success the local jcp_dw points into the pd now owned by dw_conv_pd_.
#define CASE(dt) \
    case dt: { \
        std::unique_ptr<dw_pd_t<dt>> fusable_pd( \
                new dw_pd_t<dt>(&cd_dw, &attr_dw, nullptr)); \
        CHECK(fusable_pd->init(engine)); \
        jcp_dw = &(fusable_pd->jcp_); \
        dw_conv_pd_ = std::move(fusable_pd); \
        break; \
    }
            if (jcp_1x1.dst_dt == data_type::bf16) {
                switch (dw_dst_dt) {
                    CASE(data_type::bf16);
                    CASE(data_type::f32);
                    default: return status::unimplemented;
                }
            } else
                return status::unimplemented;
#undef CASE

            // The dw conv must consume exactly the md the 1x1 produces, the
            // 1x1 oc must be unpadded w.r.t. its block, and the dw conv must
            // process a full output row at a time (ow_block == ow) when row
            // blocking is used.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw->is_fused_conv = true;
            // TODO: Support/experiment arbitrary oc_work in dw conv.
            // Until then we keep ch_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0)
                --jcp_dw->nb_ch_blocking;

            jcp_dw->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread intermediate buffer between the 1x1 and dw kernels:
            // kh rows of iw * dw_conv_buffer_oc elements each.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw
                    * jcp_dw->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(dw_scratchpad, *jcp_dw);

            return status::success;
        }
    };

    // init_rtus_driver needs access to rtus_driver_ and friends.
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);
    jit_avx512_core_bf16_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;
    // Note: In case of fused depthwise convolution, the final output datatype
    // may not be dst_data_t.
    // NOTE(review): dw_wei_data_t follows dst_type here while the dw kernel
    // is instantiated with bf16 weights — confirm this aliasing is intended.
    typedef typename prec_traits<dst_type>::type dw_wei_data_t;

    // Creates the 1x1 kernel, the fused dw kernel (when enabled), and the
    // reduce-to-unit-stride driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_bf16_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Implementations live in the corresponding .cpp file.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr,
            const src_data_t *src, const wei_data_t *weights, const char *bias,
            const dw_wei_data_t *weights_dw, const float *bias_dw,
            const char *dst, const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    /* reduction to unit stride */
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t
            = jit_uni_dw_conv_fwd_kernel<avx512_core, data_type::bf16>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
339
// Backward-data 1x1 convolution: bf16 diff_dst and weights, diff_src of
// diff_src_type (bf16 or f32, enforced by expect_data_types()).
template <impl::data_type_t diff_src_type>
struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:", jcp_.isa, ""),
                jit_avx512_core_bf16_1x1_convolution_bwd_data_t);

        status_t init(engine_t *engine) {
            // Order-dependent conjuncts: set_default_alg_kind() and
            // set_default_formats() mutate the pd under short-circuiting.
            // No post-ops are supported on bwd_data (default attr only).
            bool ok = true && mayiuse(avx512_core) && is_bwd_d()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(diff_src_type, data_type::bf16,
                            data_type::undef, data_type::bf16, data_type::undef)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            // Reduce-to-unit-stride preparation, analogous to fwd but
            // operating on diff_src/diff_dst.
            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *diff_src_d = diff_src_md();
            rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());

            status_t status = jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
                    jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
                    attr_, dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            status = jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);
            if (status != status::success) return status;
            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;

    protected:
        // Same layout policy as fwd: plain nxc vs 16c-blocked data, with the
        // bwd-data-specific IO 8o16i2o bf16 weights layout.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper diff_src_d(&diff_src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag = diff_src_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              diff_src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    IOw8o16i2o, gIOw8o16i2o, IOhw8o16i2o, gIOhw8o16i2o,
                    IOdhw8o16i2o, gIOdhw8o16i2o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }
    };

    // init_rtus_driver needs access to rtus_driver_ and friends.
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_bf16_1x1_convolution_bwd_data_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;
    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;

    // Creates the shared 1x1 kernel and the reduce-to-unit-stride driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_bf16_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    // Implementations live in the corresponding .cpp file.
    void execute_backward_data(const exec_ctx_t &ctx) const;
    void execute_backward_data_thr(const int, const int,
            const diff_dst_data_t *, const wei_data_t *, diff_src_data_t *,
            const memory_tracking::grantor_t &scratchpad) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    /* reduction to unit stride */
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
};
447
// Backward-weights 1x1 convolution: bf16 src and diff_dst, diff_weights of
// diff_weights_type (bf16 or f32, enforced by expect_data_types(); bias
// gradient may be f32 or bf16 independently).
template <impl::data_type_t diff_weights_type>
struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:", jcp_.isa, ""),
                jit_avx512_core_bf16_1x1_convolution_bwd_weights_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            assert(engine->kind() == engine_kind::cpu);
            // Order-dependent conjuncts: set_default_alg_kind() and
            // set_default_formats() mutate the pd under short-circuiting.
            bool ok = true && mayiuse(avx512_core) && is_bwd_w()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::bf16, diff_weights_type,
                            data_type::undef, data_type::bf16, data_type::undef)
                    && IMPLICATION(with_bias(),
                            utils::one_of(diff_weights_md(1)->data_type,
                                    data_type::f32, data_type::bf16))
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            // Reduce-to-unit-stride preparation over src/diff_dst.
            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(
                    this, conv_d, src_d, diff_dst_md(), diff_weights_md(0));

            status_t status = jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
                    jcp_, *conv_d, *src_d, *diff_weights_md(0), *diff_dst_md(),
                    attr_, dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            status = jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);
            if (status != status::success) return status;

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;

    protected:
        // Same layout policy as fwd/bwd_data for the data tensors; weights
        // gradient uses the 16i16o layout (no bf16 2i/2o interleave here).
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            bool ok = set_default_formats_common(dat_tag, wei_tag, dat_tag);
            return ok;
        }
    };

    // init_rtus_driver needs access to rtus_driver_ and friends.
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_bf16_1x1_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Defined in the .cpp file: also sets up acc_ker_ and the bf16
    // transpose/reorder helpers declared below.
    status_t init(engine_t *engine) override;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;

    typedef typename prec_traits<diff_weights_type>::type diff_wei_data_t;

private:
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    // f32 accumulator used to reduce per-thread partial weight gradients.
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;

    /* reduction to unit stride */
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;

    // bf16 reorder helpers (s16c -> S16c2s) for blocked and nhwc src/ddst.
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t> tr_reorder_;
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t>
            tr_reorder_nhwc_src_;
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t>
            tr_reorder_nhwc_ddst_;
};
562
563} // namespace x64
564} // namespace cpu
565} // namespace impl
566} // namespace dnnl
567#endif
568