1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_SSE41_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_SSE41_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_hashing.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp" |
32 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
33 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace cpu { |
38 | namespace x64 { |
39 | |
// SSE4.1 JIT implementation of forward 1x1 convolution for f32 data.
// A depthwise convolution appearing as a post-op in the attributes may be
// fused (jcp_.with_dw_conv): in that case a nested dw-conv pd/kernel is
// created and this primitive's visible dst is the dw convolution's dst.
struct jit_sse41_1x1_convolution_fwd_t : public primitive_t {
    // Primitive descriptor: validates the problem, fills the JIT kernel
    // configuration (jcp_) and, when a depthwise post-op is present, sets up
    // the fused dw convolution pd.
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = jit_sse41_dw_convolution_fwd_t::pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        // Copy ctor must deep-copy the nested dw-conv pd (see copy()); a
        // constructor cannot return a status, so failure is reported through
        // the is_initialized_ flag instead.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:" , sse41, "" ),
                jit_sse41_1x1_convolution_fwd_t);

        // Checks that the problem is supported (forward, f32 for all
        // tensors, only post-op attributes, no zero-dim memory), picks
        // default memory formats, then fills jcp_ via the kernel's
        // init_conf. Fused depthwise setup runs last since it depends on
        // the finalized 1x1 configuration.
        status_t init(engine_t *engine) {
            using namespace data_type;
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(f32, f32, f32, f32, f32)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, f32)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            CHECK(jit_sse41_1x1_conv_kernel_f32::init_conf(jcp_, *desc(),
                    *src_md(), *weights_md(), *dst_md(), *attr(),
                    dnnl_get_max_threads()));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            return status::success;
        }

        // With a fused dw conv, the externally visible dst of this
        // primitive is the dst of the fused dw convolution.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes the fused dw conv's weights/bias argument queries to the
        // nested pd; all other arguments fall through to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // The fused dw conv's weights — and its bias, when the dw post-op
        // declares more than one input — are additional inputs of this
        // primitive.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        using dw_pd_t = jit_sse41_dw_convolution_fwd_t::pd_t;
        // pd of the fused depthwise convolution; null unless with_dw_conv.
        std::unique_ptr<dw_pd_t> dw_conv_pd_;

    protected:
        // Chooses memory formats: plain channels-last (nxc) is used only if
        // both src and dst either already match it or are format_kind::any
        // AND at least one of them explicitly requested nxc; otherwise the
        // 8-channel blocked layout (nCx8c) is used. Weights are always
        // 8i8o-blocked (grouped variant when with_groups()).
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx8c
                    = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
            auto wei_tag = with_groups()
                    ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Deep-copy helper shared by the copy ctor: copies jcp_ and clones
        // the nested dw-conv pd when the source has one.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(other.dw_conv_pd_->clone());
                if (!dw_conv_pd_) return status::out_of_memory;
            }
            return status::success;
        }

        // Sets up the fused depthwise convolution: builds its conv desc
        // from this primitive's dst (which becomes the dw conv's src),
        // initializes the nested pd, makes the 1x1 and dw blocking factors
        // compatible, and books the intermediate scratchpad buffer.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            // The 1x1 conv's dst acts as the fused dw conv's src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = true && (!mayiuse(avx))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;

            // Derive the dw conv desc/attr from the dw post-op entry.
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            CHECK(safe_ptr_assign(
                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
            CHECK(dw_conv_pd_->init(engine));
            auto &jcp_dw = dw_conv_pd_->jcp_;

            // Fusion is only valid when the dw conv consumes exactly the
            // 1x1's dst layout, the 1x1 oc has no padding w.r.t. its block,
            // and the dw kernel does not split the output width.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw.is_fused_conv = true;
            // TODO: Support/experiment arbitrary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            // Keep the dw channel blocking a divisor of the (adjusted) 1x1
            // load blocking so both kernels agree on the chunk of channels
            // exchanged through the intermediate buffer.
            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
                --jcp_dw.nb_ch_blocking;

            jcp_dw.dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
            const bool is_data_nxc = utils::everyone_is(
                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
            if (!is_data_nxc)
                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
                        * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread row buffer holding kh rows of 1x1 output that the
            // dw kernel consumes as input.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
                    * jcp_dw.dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            jit_uni_dw_conv_fwd_kernel<sse41, data_type::f32>::init_scratchpad(
                    dw_scratchpad, jcp_dw);

            return status::success;
        }
    };

    jit_sse41_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    // Generates the 1x1 JIT kernel and, when a dw conv is fused, the dw
    // JIT kernel as well.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_sse41_1x1_conv_kernel_f32(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(
                            pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0))));
            return kernel_dw_->create_kernel();
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Implemented in the corresponding .cpp: dispatches execute_forward_thr
    // across threads.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr, const data_t *src,
            const data_t *weights, const data_t *bias, const data_t *weights_dw,
            const data_t *bias_dw, data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<jit_sse41_1x1_conv_kernel_f32> kernel_;
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<sse41>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
275 | |
276 | } // namespace x64 |
277 | } // namespace cpu |
278 | } // namespace impl |
279 | } // namespace dnnl |
280 | |
281 | #endif |
282 | |