1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_hashing.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/cpu_reducer.hpp" |
32 | #include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp" |
33 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
34 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
struct jit_avx2_1x1_convolution_fwd_t : public primitive_t {
    // Forward f32 1x1 convolution backed by the avx2 JIT kernel. Supports an
    // optionally fused depthwise-convolution post-op and uses the
    // reduce-to-unit-stride (rtus) driver for strided inputs.
    // TODO: (Roma) Code duplication! Remove with templates
    // (maybe...)!
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copying requires a deep copy of the fused depthwise pd (see
        // copy()); since a constructor cannot return a status, failure is
        // signaled by clearing is_initialized_.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , jcp_.isa, "" ),
                jit_avx2_1x1_convolution_fwd_t);

        // Checks applicability, fills jcp_ via the kernel's init_conf(),
        // sets up the fused depthwise conv when requested by the post-ops,
        // and books all scratchpad the implementation needs.
        status_t init(engine_t *engine) {
            bool ok = true && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops,
                            data_type::f32)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx2_1x1_conv_kernel_f32::init_conf(
                    jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr()));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // When a depthwise conv is fused, the primitive's dst is the fused
        // conv's dst; otherwise it is the 1x1 conv's own dst.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes queries for the fused depthwise weights/bias descriptors to
        // the nested depthwise pd; everything else goes to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // The fused depthwise weights (and bias, when present) are extra
        // inputs of this primitive.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;
        // Non-owning pointer into dw_conv_pd_'s jcp_; set by copy() /
        // depthwise_po_init(), nullptr when no depthwise conv is fused.
        jit_conv_conf_t *jcp_dw_;
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;

    protected:
        template <cpu_isa_t isa>
        using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<isa,
                data_type::f32>::pd_t;

        // Picks plain (nxc) vs blocked (nCx8c) data layouts and the matching
        // 8i8o weights layout, then applies them to any "any"-format mds.
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx8c
                    = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            // nxc is used only if neither tensor is pinned to another layout
            // and at least one of them already is nxc.
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
            auto wei_tag = with_groups()
                    ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Deep-copies `other`, cloning the nested depthwise pd (if any) and
        // rebinding jcp_dw_ into the freshly cloned object — jcp_dw_ must
        // never point into `other`.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;
                // The concrete dw pd type depends on the ISA picked at
                // creation time (see depthwise_po_init()).
                if (jcp_.isa == avx2) {
                    jcp_dw_ = &(static_cast<dw_pd_t<avx2> *>(dw_conv_pd_.get())
                                        ->jcp_);
                } else { // sse41
                    jcp_dw_ = &(static_cast<dw_pd_t<sse41> *>(dw_conv_pd_.get())
                                        ->jcp_);
                }
            }

            return status::success;
        }

        // Creates and validates the fused depthwise conv pd from the
        // convolution post-op, then co-tunes the 1x1 blocking so the two
        // kernels can stream data through a shared scratchpad buffer.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            jit_conv_conf_t *jcp_dw = nullptr;

            // The 1x1 conv's dst is the src of the fused depthwise conv.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = true && (!mayiuse(avx512_core))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;

            // Derive the depthwise conv desc/attr from the post-op entry.
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            if (jcp_1x1.isa == avx2) {
                std::unique_ptr<dw_pd_t<avx2>> fusable_pd(
                        new dw_pd_t<avx2>(&cd_dw, &attr_dw, nullptr));
                CHECK(fusable_pd->init(engine));
                jcp_dw = &(fusable_pd->jcp_);
                dw_conv_pd_ = std::move(fusable_pd);
            } else {
                // Special case for this primitive, as we don't have dw<avx>.
                // In this case fuse with sse41 depthwise conv
                // NOTE: Currently dw f32 kernel is similar for all ISA and can
                // be fused regardless of ISA if inter-connecting md_ matches.
                std::unique_ptr<dw_pd_t<sse41>> fusable_pd(
                        new dw_pd_t<sse41>(&cd_dw, &attr_dw, nullptr));
                CHECK(fusable_pd->init(engine));
                jcp_dw = &(fusable_pd->jcp_);
                dw_conv_pd_ = std::move(fusable_pd);
            }

            // The two kernels must agree on the inter-connecting md and on
            // how output channels are blocked.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw->is_fused_conv = true;
            // TODO: Support/experiment arbitrary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0)
                --jcp_dw->nb_ch_blocking;

            jcp_dw->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
            const bool is_data_nxc = utils::everyone_is(
                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
            if (!is_data_nxc)
                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
                        * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread ring buffer that carries the 1x1 output rows into
            // the depthwise kernel (kh rows of iw * buffer_oc each).
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw
                    * jcp_dw->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            if (jcp_1x1.isa == avx2)
                dw_conv_kernel_t<avx2>::init_scratchpad(dw_scratchpad, *jcp_dw);
            else
                dw_conv_kernel_t<sse41>::init_scratchpad(
                        dw_scratchpad, *jcp_dw);

            return status::success;
        }
    };

    // Grants the rtus driver access to kernel_/rtus_driver_/pd().
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    // Generates the 1x1 kernel, the rtus driver, and — when fused — the
    // depthwise kernel matching the ISA chosen at pd creation.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx2_1x1_conv_kernel_f32(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx2>(this));
        if (pd()->jcp_.with_dw_conv) {
            auto &isa = pd()->jcp_.isa;

            if (isa == avx2) {
                CHECK(safe_ptr_assign(kernel_dw_avx2,
                        new dw_conv_kernel_t<avx2>(
                                *(pd()->jcp_dw_), *pd()->dst_md(0))));
                CHECK(kernel_dw_avx2->create_kernel());
            } else {
                CHECK(safe_ptr_assign(kernel_dw_sse41,
                        new dw_conv_kernel_t<sse41>(
                                *(pd()->jcp_dw_), *pd()->dst_md(0))));
                CHECK(kernel_dw_sse41->create_kernel());
            }
        }

        return status::success;
    }

    typedef typename prec_traits<data_type::f32>::type data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    void execute_forward(const exec_ctx_t &ctx) const;
    // Per-thread slice of the forward pass; called from execute_forward().
    void execute_forward_thr(const int ithr, const int nthr, const data_t *src,
            const data_t *weights, const data_t *bias, const data_t *weights_dw,
            const data_t *bias_dw, data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
    std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;

    template <cpu_isa_t isa>
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel<isa, data_type::f32>;

    // Exactly one of these is created, depending on pd()->jcp_.isa.
    std::unique_ptr<dw_conv_kernel_t<avx2>> kernel_dw_avx2;
    std::unique_ptr<dw_conv_kernel_t<sse41>> kernel_dw_sse41;
};
342 | |
343 | struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t { |
344 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
345 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
346 | const convolution_fwd_pd_t *hint_fwd_pd) |
347 | : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) |
348 | , jcp_() |
349 | , rtus_() {} |
350 | |
351 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , avx2, "" ), |
352 | jit_avx2_1x1_convolution_bwd_data_t); |
353 | |
354 | status_t init(engine_t *engine) { |
355 | bool ok = true && desc()->prop_kind == prop_kind::backward_data |
356 | && set_default_alg_kind(alg_kind::convolution_direct) |
357 | && expect_data_types(data_type::f32, data_type::f32, |
358 | data_type::undef, data_type::f32, data_type::f32) |
359 | && attr()->has_default_values() && !has_zero_dim_memory() |
360 | && set_default_formats(); |
361 | if (!ok) return status::unimplemented; |
362 | |
363 | const convolution_desc_t *conv_d = desc(); |
364 | const memory_desc_t *diff_src_d = diff_src_md(); |
365 | rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md()); |
366 | |
367 | status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, |
368 | *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), |
369 | *attr()); |
370 | if (status != status::success) return status; |
371 | |
372 | auto scratchpad = scratchpad_registry().registrar(); |
373 | jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); |
374 | |
375 | rtus_prepare_space_info(this, scratchpad, jcp_.nthr); |
376 | |
377 | return status::success; |
378 | } |
379 | |
380 | jit_1x1_conv_conf_t jcp_; |
381 | reduce_to_unit_stride_t rtus_; |
382 | |
383 | protected: |
384 | bool set_default_formats() { |
385 | using namespace format_tag; |
386 | |
387 | const memory_desc_wrapper diff_src_d(&diff_src_md_); |
388 | const memory_desc_wrapper diff_dst_d(&diff_dst_md_); |
389 | |
390 | const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); |
391 | const auto dat_tag_nCx8c |
392 | = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); |
393 | const auto curr_src_tag |
394 | = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
395 | const auto curr_dst_tag |
396 | = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
397 | const auto is_data_layout_nxc |
398 | = IMPLICATION(curr_src_tag != dat_tag_nxc, |
399 | diff_src_d.format_kind() == format_kind::any) |
400 | && IMPLICATION(curr_dst_tag != dat_tag_nxc, |
401 | diff_dst_d.format_kind() == format_kind::any) |
402 | && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
403 | auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; |
404 | auto wei_tag = with_groups() |
405 | ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) |
406 | : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); |
407 | |
408 | return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
409 | } |
410 | }; |
411 | |
412 | template <cpu_isa_t isa, typename conv_t> |
413 | friend status_t init_rtus_driver(conv_t *self); |
414 | |
415 | jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
416 | |
417 | typedef typename prec_traits<data_type::f32>::type data_t; |
418 | |
419 | status_t init(engine_t *engine) override { |
420 | CHECK(safe_ptr_assign(kernel_, |
421 | new jit_avx2_1x1_conv_kernel_f32( |
422 | pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); |
423 | CHECK(kernel_->create_kernel()); |
424 | CHECK(init_rtus_driver<avx2>(this)); |
425 | return status::success; |
426 | } |
427 | |
428 | status_t execute(const exec_ctx_t &ctx) const override { |
429 | execute_backward_data(ctx); |
430 | return status::success; |
431 | } |
432 | |
433 | private: |
434 | void execute_backward_data(const exec_ctx_t &ctx) const; |
435 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
436 | |
437 | std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_; |
438 | std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_; |
439 | }; |
440 | |
441 | struct jit_avx2_1x1_convolution_bwd_weights_t : public primitive_t { |
442 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
443 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
444 | const convolution_fwd_pd_t *hint_fwd_pd) |
445 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
446 | , jcp_() |
447 | , rtus_() {} |
448 | |
449 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , avx2, "" ), |
450 | jit_avx2_1x1_convolution_bwd_weights_t); |
451 | |
452 | status_t init(engine_t *engine) { |
453 | bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
454 | && set_default_alg_kind(alg_kind::convolution_direct) |
455 | && expect_data_types(data_type::f32, data_type::f32, |
456 | data_type::f32, data_type::f32, data_type::f32) |
457 | && attr()->has_default_values() && !has_zero_dim_memory() |
458 | && set_default_formats(); |
459 | if (!ok) return status::unimplemented; |
460 | |
461 | const convolution_desc_t *conv_d = desc(); |
462 | const memory_desc_t *src_d = src_md(); |
463 | rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md()); |
464 | |
465 | status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, |
466 | *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), |
467 | *attr()); |
468 | if (status != status::success) return status; |
469 | |
470 | init_balancers(); |
471 | |
472 | auto scratchpad = scratchpad_registry().registrar(); |
473 | jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); |
474 | |
475 | rtus_prepare_space_info(this, scratchpad, jcp_.nthr); |
476 | |
477 | auto reducer_bia_scratchpad = memory_tracking::registrar_t( |
478 | scratchpad, memory_tracking::names::prefix_reducer_bia); |
479 | reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); |
480 | |
481 | auto reducer_wei_scratchpad = memory_tracking::registrar_t( |
482 | scratchpad, memory_tracking::names::prefix_reducer_wei); |
483 | reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); |
484 | |
485 | return status::success; |
486 | } |
487 | |
488 | jit_1x1_conv_conf_t jcp_; |
489 | cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; |
490 | cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_; |
491 | reduce_to_unit_stride_t rtus_; |
492 | |
493 | protected: |
494 | bool set_default_formats() { |
495 | using namespace format_tag; |
496 | |
497 | const memory_desc_wrapper src_d(&src_md_); |
498 | const memory_desc_wrapper diff_dst_d(&diff_dst_md_); |
499 | |
500 | const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); |
501 | const auto dat_tag_nCx8c |
502 | = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); |
503 | const auto curr_src_tag |
504 | = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
505 | const auto curr_dst_tag |
506 | = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c); |
507 | const auto is_data_layout_nxc |
508 | = IMPLICATION(curr_src_tag != dat_tag_nxc, |
509 | src_d.format_kind() == format_kind::any) |
510 | && IMPLICATION(curr_dst_tag != dat_tag_nxc, |
511 | diff_dst_d.format_kind() == format_kind::any) |
512 | && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); |
513 | |
514 | auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; |
515 | auto wei_tag = with_groups() |
516 | ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o) |
517 | : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o); |
518 | |
519 | return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
520 | } |
521 | |
522 | private: |
523 | void init_balancers() { |
524 | const int ic_block = jcp_.bcast_block; |
525 | const int nb_ic = jcp_.nb_bcast; |
526 | const int nb_ic_blocking = jcp_.nb_bcast_blocking; |
527 | const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); |
528 | |
529 | const int oc_block = jcp_.load_block; |
530 | const int nb_oc = jcp_.nb_load; |
531 | const int nb_oc_blocking = jcp_.nb_load_blocking; |
532 | const int load_work = utils::div_up(nb_oc, nb_oc_blocking); |
533 | |
534 | const int job_size |
535 | = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; |
536 | const int njobs_x = bcast_work; |
537 | const int njobs_y = jcp_.ngroups * load_work; |
538 | |
539 | const int max_threads = dnnl_get_max_threads(); |
540 | const size_t max_buffer_size = (size_t)max_threads * job_size * 8; |
541 | |
542 | if (with_bias()) { |
543 | reducer_bia_conf_.init(reduce_balancer_t(max_threads, oc_block, |
544 | jcp_.ngroups * nb_oc, jcp_.mb, max_buffer_size, true)); |
545 | } |
546 | |
547 | reducer_wei_conf_.init( |
548 | reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, |
549 | jcp_.mb * jcp_.nb_reduce, max_buffer_size, true), |
550 | job_size / nb_oc_blocking, nb_oc_blocking, ic_block, |
551 | nb_ic * ic_block * oc_block, nb_oc); |
552 | } |
553 | }; |
554 | |
555 | template <cpu_isa_t isa, typename conv_t> |
556 | friend status_t init_rtus_driver(conv_t *self); |
557 | |
558 | jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd) |
559 | : primitive_t(apd) {} |
560 | |
561 | typedef typename prec_traits<data_type::f32>::type data_t; |
562 | |
563 | status_t init(engine_t *engine) override; |
564 | |
565 | status_t execute(const exec_ctx_t &ctx) const override { |
566 | execute_backward_weights(ctx); |
567 | return status::success; |
568 | } |
569 | |
570 | private: |
571 | void execute_backward_weights(const exec_ctx_t &ctx) const; |
572 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
573 | |
574 | std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_; |
575 | std::unique_ptr<cpu_reducer_2d_t<data_type::f32>> reducer_weights_; |
576 | std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_; |
577 | std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_; |
578 | }; |
579 | |
580 | } // namespace x64 |
581 | } // namespace cpu |
582 | } // namespace impl |
583 | } // namespace dnnl |
584 | |
585 | #endif |
586 | |