1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/primitive.hpp" |
23 | #include "common/primitive_hashing.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_convolution_pd.hpp" |
27 | #include "cpu/cpu_engine.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/cpu_reducer.hpp" |
32 | #include "cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp" |
33 | #include "cpu/x64/jit_transpose_utils.hpp" |
34 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
35 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
// Forward bf16 1x1 convolution implemented with an avx512_core JIT kernel.
// Optionally executes a fused depthwise convolution post-op (see
// pd_t::depthwise_po_init()) with its own JIT kernel.
template <impl::data_type_t dst_type>
struct jit_avx512_core_bf16_1x1_convolution_fwd_t : public primitive_t {
    // Primitive descriptor: validates the problem, fills the JIT kernel
    // configuration (jcp_), prepares the reduce-to-unit-stride (rtus) state,
    // and, when the attributes request it, sets up a fused depthwise
    // convolution (dw_conv_pd_ / jcp_dw_).
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copy ctor deep-copies the fused dw conv pd and re-points jcp_dw_
        // into our own clone (see copy()); on failure the pd is marked
        // uninitialized.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:" , jcp_.isa, "" ),
                jit_avx512_core_bf16_1x1_convolution_fwd_t);

        // Checks implementation prerequisites (ISA, propagation kind, data
        // types, attributes, memory formats), then initializes the kernel
        // configuration and books scratchpad space.
        status_t init(engine_t *engine) {
            bool ok = true && mayiuse(avx512_core) && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    // bf16 src/weights; bias (if any) f32 or bf16; dst is
                    // the template parameter dst_type.
                    && expect_data_types(data_type::bf16, data_type::bf16,
                            data_type::undef, dst_type, data_type::undef)
                    && IMPLICATION(with_bias(),
                            utils::one_of(weights_md(1)->data_type,
                                    data_type::f32, data_type::bf16))
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            // Non-unit strides are handled by reducing src to unit stride
            // (rtus); may set rtus_.reduce_src_.
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_core_bf16_1x1_conv_kernel::init_conf(jcp_, *conv_d,
                    *src_d, *weights_md(), *dst_md(), attr_,
                    dnnl_get_max_threads(), rtus_.reduce_src_));

            // init_conf() decides whether a dw conv post-op is fused.
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            CHECK(jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_));

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv, the externally visible dst is the dw conv's.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes the fused dw conv's weights/bias argument queries to its pd.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            // dw bias is an input only when the fused dw conv provides more
            // than one input (i.e. it actually carries a bias).
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // 1x1 conv kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride driver state
        jit_conv_conf_t *jcp_dw_; // doesn't own a resource
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;

    protected:
        // pd type of the fused depthwise conv, parameterized on its dst dt.
        template <data_type_t ddt>
        using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<avx512_core,
                data_type::bf16, ddt>::pd_t;

        // Picks plain (nxc) vs 16c-blocked activations — nxc only if the
        // user-given tags allow it — plus the matching blocked weights tag.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            // nxc is chosen only when neither src nor dst is pinned to a
            // different tag and at least one of them already is (or may
            // become) nxc.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);

            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw8i16o2i, gOIw8i16o2i, OIhw8i16o2i, gOIhw8i16o2i,
                    OIdhw8i16o2i, gOIdhw8i16o2i);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Copies state from `other`, deep-cloning the fused dw conv pd.
        // jcp_dw_ must point into our own clone, so it is re-derived here
        // from the clone's concrete pd type (selected by its dst data type).
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            using namespace data_type;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;
                auto dw_dst_dt = dw_conv_pd_->dst_md()->data_type;

                switch (dw_dst_dt) {
                    case bf16:
                        jcp_dw_ = &(
                                static_cast<dw_pd_t<bf16> *>(dw_conv_pd_.get())
                                        ->jcp_);
                        break;
                    case f32:
                        jcp_dw_ = &(
                                static_cast<dw_pd_t<f32> *>(dw_conv_pd_.get())
                                        ->jcp_);
                        break;
                    default: assert(!"unreachable" );
                }
            }
            return status::success;
        }

        // Builds the pd of the depthwise conv post-op fused after this 1x1
        // conv, validates that fusion is profitable/supported, and aligns the
        // blocking of both kernels so the intermediate buffer matches.
        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            jit_conv_conf_t *jcp_dw = nullptr;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            // The dw conv consumes this primitive's dst as its src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = !mayiuse(avx512_core_amx)
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            // Derives the dw conv desc/attr from the dw post-op entry.
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            auto dw_dst_dt = cd_dw.dst_desc.data_type;

// Instantiates and initializes the concrete dw pd for dst data type `dt`.
#define CASE(dt) \
    case dt: { \
        std::unique_ptr<dw_pd_t<dt>> fusable_pd( \
                new dw_pd_t<dt>(&cd_dw, &attr_dw, nullptr)); \
        CHECK(fusable_pd->init(engine)); \
        jcp_dw = &(fusable_pd->jcp_); \
        dw_conv_pd_ = std::move(fusable_pd); \
        break; \
    }
            if (jcp_1x1.dst_dt == data_type::bf16) {
                switch (dw_dst_dt) {
                    CASE(data_type::bf16);
                    CASE(data_type::f32);
                    default: return status::unimplemented;
                }
            } else
                return status::unimplemented;
#undef CASE

            // The dw conv must accept our dst layout as-is, and the blocking
            // constraints below must be satisfiable.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw->is_fused_conv = true;
            // TODO: Support/experiment arbitary oc_work in dw conv.
            // Until then we keep ch_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            // dw channel blocking must evenly divide the 1x1 load blocking.
            while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0)
                --jcp_dw->nb_ch_blocking;

            jcp_dw->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread intermediate buffer between the 1x1 and dw kernels.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw
                    * jcp_dw->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(dw_scratchpad, *jcp_dw);

            return status::success;
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);
    jit_avx512_core_bf16_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;
    // Note: In case of fused depthwise convolution, the final output datatype
    // may not be dst_data_t.
    typedef typename prec_traits<dst_type>::type dw_wei_data_t;

    // Generates the 1x1 kernel, the fused dw conv kernel (if any), and the
    // rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_bf16_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_), *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    // execute_forward() returns void, so this always reports success.
    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    void execute_forward(const exec_ctx_t &ctx) const;
    // Per-thread body of execute_forward().
    void execute_forward_thr(const int ithr, const int nthr,
            const src_data_t *src, const wei_data_t *weights, const char *bias,
            const dw_wei_data_t *weights_dw, const float *bias_dw,
            const char *dst, const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t
            = jit_uni_dw_conv_fwd_kernel<avx512_core, data_type::bf16>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
339 | |
// Backward-data bf16 1x1 convolution (computes diff_src from diff_dst and
// weights) implemented with an avx512_core JIT kernel.
template <impl::data_type_t diff_src_type>
struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t : public primitive_t {
    // Primitive descriptor: validates the problem and fills the JIT kernel
    // configuration (jcp_) and reduce-to-unit-stride (rtus) state.
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:" , jcp_.isa, "" ),
                jit_avx512_core_bf16_1x1_convolution_bwd_data_t);

        // Checks prerequisites (ISA, propagation kind, data types, default
        // attributes, formats), then initializes the kernel configuration
        // and books scratchpad space.
        status_t init(engine_t *engine) {
            bool ok = true && mayiuse(avx512_core) && is_bwd_d()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    // diff_src is the template parameter type; weights and
                    // diff_dst are bf16; no bias.
                    && expect_data_types(diff_src_type, data_type::bf16,
                            data_type::undef, data_type::bf16, data_type::undef)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *diff_src_d = diff_src_md();
            // Non-unit strides are handled via the rtus driver.
            rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());

            status_t status = jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
                    jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
                    attr_, dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            status = jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);
            if (status != status::success) return status;
            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_; // 1x1 conv kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride driver state

    protected:
        // Picks plain (nxc) vs 16c-blocked activations and the matching
        // IO-transposed blocked weights tag for backward data.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper diff_src_d(&diff_src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag = diff_src_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            // nxc only when neither tensor is pinned to another tag and at
            // least one of them already matches nxc.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              diff_src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    IOw8o16i2o, gIOw8o16i2o, IOhw8o16i2o, gIOhw8o16i2o,
                    IOdhw8o16i2o, gIOdhw8o16i2o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_bf16_1x1_convolution_bwd_data_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;
    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;

    // Generates the JIT kernel and the rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_bf16_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    // execute_backward_data() returns void, so this always reports success.
    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    void execute_backward_data(const exec_ctx_t &ctx) const;
    // Per-thread body of execute_backward_data().
    void execute_backward_data_thr(const int, const int,
            const diff_dst_data_t *, const wei_data_t *, diff_src_data_t *,
            const memory_tracking::grantor_t &scratchpad) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    /* reduction to unit stride */
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
};
447 | |
// Backward-weights bf16 1x1 convolution (computes diff_weights, and
// diff_bias when present, from src and diff_dst) with an avx512_core JIT
// kernel; uses an f32 accumulator and bf16 reorder helpers.
template <impl::data_type_t diff_weights_type>
struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t : public primitive_t {
    // Primitive descriptor: validates the problem and fills the JIT kernel
    // configuration (jcp_) and reduce-to-unit-stride (rtus) state.
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_1x1:" , jcp_.isa, "" ),
                jit_avx512_core_bf16_1x1_convolution_bwd_weights_t);

        // Checks prerequisites (ISA, propagation kind, data types, default
        // attributes, formats), then initializes the kernel configuration
        // and books scratchpad space.
        status_t init(engine_t *engine) {
            using namespace prop_kind;
            assert(engine->kind() == engine_kind::cpu);
            bool ok = true && mayiuse(avx512_core) && is_bwd_w()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    // bf16 src/diff_dst; diff_weights is the template
                    // parameter type; diff_bias (if any) f32 or bf16.
                    && expect_data_types(data_type::bf16, diff_weights_type,
                            data_type::undef, data_type::bf16, data_type::undef)
                    && IMPLICATION(with_bias(),
                            utils::one_of(diff_weights_md(1)->data_type,
                                    data_type::f32, data_type::bf16))
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            // Non-unit strides are handled via the rtus driver.
            rtus_prepare(
                    this, conv_d, src_d, diff_dst_md(), diff_weights_md(0));

            status_t status = jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
                    jcp_, *conv_d, *src_d, *diff_weights_md(0), *diff_dst_md(),
                    attr_, dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            status = jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);
            if (status != status::success) return status;

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_; // 1x1 conv kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride driver state

    protected:
        // Picks plain (nxc) vs 16c-blocked activations and the 16i16o
        // blocked diff_weights tag for backward weights.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper diff_dst_d(&diff_dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx16c
                    = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
            const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
                    dat_tag_nxc, dat_tag_nCx16c);
            // nxc only when neither tensor is pinned to another tag and at
            // least one of them already matches nxc.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            diff_dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            bool ok = set_default_formats_common(dat_tag, wei_tag, dat_tag);
            return ok;
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_bf16_1x1_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Defined out of line (sets up kernel_, acc_ker_, rtus and reorders).
    status_t init(engine_t *engine) override;

    // execute_backward_weights() returns void, so this always reports
    // success.
    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;

    typedef typename prec_traits<diff_weights_type>::type diff_wei_data_t;

private:
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_bf16_1x1_conv_kernel> kernel_;
    // f32 accumulator used to reduce partial per-thread results.
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;

    /* reduction to unit stride */
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;

    // bf16 reorder helpers (s16c -> S16c2s); the nhwc variants are used for
    // plain-layout src/diff_dst.
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t> tr_reorder_;
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t>
            tr_reorder_nhwc_src_;
    std::unique_ptr<jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t>
            tr_reorder_nhwc_ddst_;
};
562 | |
563 | } // namespace x64 |
564 | } // namespace cpu |
565 | } // namespace impl |
566 | } // namespace dnnl |
567 | #endif |
568 | |