1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP
18#define CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/primitive_hashing.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/dw_convolution_utils.hpp"
29#include "cpu/platform.hpp"
30
31#include "cpu/x64/cpu_reducer.hpp"
32#include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp"
33#include "cpu/x64/jit_transpose_utils.hpp"
34#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
35#include "cpu/x64/jit_uni_dw_convolution.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace cpu {
40namespace x64 {
41
// Forward 1x1 convolution driven by an avx512_core JIT kernel, with optional
// fusion of a depthwise convolution post-op (a second JIT kernel is created
// in that case and dst/arg descriptors are redirected to the fused dw pd).
// Weights and destination data types default to the source data type.
template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
        impl::data_type_t dst_type = src_type>
struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        // Copy constructor must deep-copy the optional fused dw-conv pd;
        // on failure the pd is flagged as not initialized.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_core, ""),
                jit_avx512_common_1x1_convolution_fwd_t);

        // Applicability checks (fwd prop, direct alg, expected data types,
        // default attributes modulo post-ops, non-zero dims, formats), then:
        // prepare the reduce-to-unit-stride (rtus) transform, build the JIT
        // kernel configuration, optionally set up dw-conv fusion, and book
        // scratchpad space.
        status_t init(engine_t *engine) {
            using namespace utils;
            bool ok = true && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(src_type, wei_type, dst_type, dst_type,
                            data_type::undef)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, *conv_d,
                    *src_d, *weights_md(), *dst_md(), *attr(),
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            // init_conf() decides whether the convolution post-op can be
            // handled as a fused depthwise conv (jcp_.with_dw_conv).
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With dw fusion the primitive's visible dst is the dw conv's dst.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Route the fused dw conv's weights/bias descriptors for the
        // DNNL_ARG_ATTR_POST_OP_DW argument indices; everything else falls
        // back to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // Fused dw weights are always an input; fused dw bias is an input
        // only when the dw post-op actually carries one (> 1 inputs).
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride helper state
        using dw_pd_t = jit_avx512_common_dw_convolution_fwd_t::pd_t;
        std::unique_ptr<dw_pd_t> dw_conv_pd_; // set only when fusing a dw conv

    protected:
        // Pick default memory formats: plain channels-last (nxc) when the
        // user hints at it (or leaves formats as 'any' consistently),
        // otherwise 16c-blocked layouts; weights are always 16i16o-blocked.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            // nxc is chosen iff at least one of src/dst is already nxc and
            // the other is either nxc too or left as format 'any'.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Deep-copy helper used by the copy constructor (clones the fused
        // dw-conv pd if present).
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(other.dw_conv_pd_->clone());
                if (!dw_conv_pd_) return status::out_of_memory;
            }
            return status::success;
        }

        // Set up the fused depthwise convolution: create its pd from the
        // convolution post-op, validate compatibility with the 1x1 conv,
        // align the blocking parameters of both kernels, and book the
        // intermediate buffer in the fusion scratchpad.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            // The 1x1 conv's dst is the dw conv's src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            // TODO: Add a check if better ISA exists following above note.
            bool ok = true
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            // Split the attributes at the dw post-op: build the dw conv desc
            // and its own attribute set from the post-op chain.
            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);
            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            CHECK(safe_ptr_assign(
                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
            CHECK(dw_conv_pd_->init(engine));
            auto &jcp_dw = dw_conv_pd_->jcp_;

            // The dw conv must consume exactly what the 1x1 conv produces.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw.is_fused_conv = true;
            // TODO: Support/experiment arbitary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            // dw channel blocking must divide the 1x1 load blocking so the
            // two kernels agree on the per-iteration channel chunk.
            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
                --jcp_dw.nb_ch_blocking;

            jcp_dw.dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
            const bool is_data_nxc = utils::everyone_is(
                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
            if (!is_data_nxc)
                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
                        * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread row buffer handed from the 1x1 kernel to the dw
            // kernel (kh rows of iw x dw_conv_buffer_oc elements).
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
                    * jcp_dw.dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            jit_uni_dw_conv_fwd_kernel<avx512_core,
                    data_type::f32>::init_scratchpad(dw_scratchpad, jcp_dw);

            return status::success;
        }
    };

    // init_rtus_driver needs access to private members (rtus_driver_ etc.).
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<src_type>::type src_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;

    // Generate the 1x1 JIT kernel (plus the dw kernel when fusing) and the
    // rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_common_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(
                            pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Implemented in the corresponding .cpp; execute_forward_thr is the
    // per-thread body invoked by execute_forward.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr,
            const src_data_t *src, const wei_data_t *weights,
            const dst_data_t *bias, const wei_data_t *weights_dw,
            const dst_data_t *bias_dw, dst_data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<avx512_core>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_; // only with dw fusion
};
303
// Convenience alias: all-f32 forward 1x1 convolution.
using jit_avx512_common_1x1_convolution_fwd_f32_t
        = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
306
// Backward-by-data 1x1 convolution driven by an avx512_core JIT kernel.
// Weights and diff_src data types default to the diff_dst data type.
template <impl::data_type_t diff_dst_type,
        impl::data_type_t wei_type = diff_dst_type,
        impl::data_type_t diff_src_type = diff_dst_type>
struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_core, ""),
                jit_avx512_common_1x1_convolution_bwd_data_t);

        // Applicability checks (backward_data prop, direct alg, data types,
        // strictly default attributes, non-zero dims, formats), then rtus
        // preparation, JIT config and scratchpad booking.
        status_t init(engine_t *engine) {
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(diff_src_type, wei_type,
                            data_type::undef, diff_dst_type, data_type::undef)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *diff_src_d = diff_src_md();
            rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());

            status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
                    *attr(), dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride helper state

    protected:
        // Pick default memory formats for diff_src/diff_dst (nxc vs
        // 16c-blocked, same selection rule as the forward pd) and
        // 16o16i-blocked weights (IO order, as needed by backward-data).
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper diff_src_d(&diff_src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag = diff_src_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            // nxc is chosen iff at least one tensor is already nxc and the
            // other is either nxc too or left as format 'any'.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              diff_src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, IOdhw16o16i,
                    gIOdhw16o16i);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }
    };

    // init_rtus_driver needs access to private members (rtus_driver_ etc.).
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;

    // Generate the JIT kernel and the rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_common_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    // Implemented in the corresponding .cpp.
    void execute_backward_data(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
};
412
// Convenience alias: all-f32 backward-by-data 1x1 convolution.
using jit_avx512_common_1x1_convolution_bwd_data_f32_t
        = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
415
// Backward-by-weights 1x1 convolution (f32 only) driven by an avx512_core
// JIT kernel, with a cpu_reducer-based bias reduction across threads.
struct jit_avx512_common_1x1_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_core, ""),
                jit_avx512_common_1x1_convolution_bwd_weights_t);

        // Applicability checks (backward_weights prop, direct alg, all-f32
        // data types, strictly default attributes, non-zero dims, formats),
        // then rtus preparation, JIT config, bias-reducer balancing and
        // scratchpad booking.
        status_t init(engine_t *engine) {
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md());

            status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
                    *attr(), dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            // Must run after init_conf(): the balancer reads jcp_ fields.
            init_balancers();

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            auto reducer_bia_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_bia);
            reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; // bias reduce
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride helper state

    protected:
        // Pick default memory formats for src/diff_dst (nxc vs 16c-blocked,
        // same selection rule as the forward pd) and 16i16o-blocked
        // diff_weights.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            // nxc is chosen iff at least one tensor is already nxc and the
            // other is either nxc too or left as format 'any'.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);

            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

    private:
        // Configure the bias reducer's thread balancer (only when a bias
        // gradient is requested); max_buffer_size caps the reduction space.
        void init_balancers() {
            const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
            if (with_bias()) {
                reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
                        jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb,
                        max_buffer_size, true));
            }
        }
    };

    // init_rtus_driver needs access to private members (rtus_driver_ etc.).
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    // Defined in the corresponding .cpp (creates kernel, reducers, etc.).
    status_t init(engine_t *engine) override;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
    // Implemented in the corresponding .cpp.
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_;
    std::unique_ptr<jit_transpose4x16_src> trans_kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
};
530
531} // namespace x64
532} // namespace cpu
533} // namespace impl
534} // namespace dnnl
535
536#endif
537