1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_SSE41_1X1_CONVOLUTION_HPP
18#define CPU_X64_JIT_SSE41_1X1_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/primitive_hashing.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/dw_convolution_utils.hpp"
29#include "cpu/platform.hpp"
30
31#include "cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp"
32#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
33#include "cpu/x64/jit_uni_dw_convolution.hpp"
34
35namespace dnnl {
36namespace impl {
37namespace cpu {
38namespace x64 {
39
// f32 forward 1x1 convolution implemented with an SSE4.1 JIT kernel.
// Optionally fuses a depthwise convolution post-op: the 1x1 output is staged
// through a per-thread scratchpad buffer and consumed by a jit_uni_dw
// convolution kernel (see pd_t::depthwise_po_init()).
struct jit_sse41_1x1_convolution_fwd_t : public primitive_t {
    // Primitive descriptor: validates applicability, fills jcp_ via the JIT
    // kernel's init_conf, and (when a dw conv post-op is present) creates and
    // wires up the fused depthwise convolution pd.
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = jit_sse41_dw_convolution_fwd_t::pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        // Copy ctor must deep-copy the owned dw conv pd; on failure the pd is
        // flagged uninitialized instead of throwing.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", sse41, ""),
                jit_sse41_1x1_convolution_fwd_t);

        // Applicability checks (forward direct f32 convolution, default
        // attributes modulo post-ops, non-empty tensors), then kernel config.
        // Note: the && chain relies on short-circuit evaluation — the
        // set_default_* calls have side effects and only run if the earlier
        // predicates hold.
        status_t init(engine_t *engine) {
            using namespace data_type;
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(f32, f32, f32, f32, f32)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, f32)
                    && !has_zero_dim_memory() && set_default_formats()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            CHECK(jit_sse41_1x1_conv_kernel_f32::init_conf(jcp_, *desc(),
                    *src_md(), *weights_md(), *dst_md(), *attr(),
                    dnnl_get_max_threads()));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            return status::success;
        }

        // When a dw conv is fused, the externally visible dst is the dw
        // conv's dst; this 1x1's own dst_md_ is only an intermediate.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes the fused dw conv's weights/bias argument queries to the dw
        // pd (bias lives in its weights_md(1)); anything else falls through
        // to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_usage == nullptr
                    ? nullptr
                    : convolution_fwd_pd_t::arg_md(index);
        }

        // Fused dw weights are always an extra input; dw bias only when the
        // dw post-op actually carries one (attr_post_op_dw_inputs() > 1 —
        // presumably weights + bias; confirm against the base pd).
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration, set in init()
        using dw_pd_t = jit_sse41_dw_convolution_fwd_t::pd_t;
        std::unique_ptr<dw_pd_t> dw_conv_pd_; // non-null iff dw conv is fused

    protected:
        // Picks data/weights memory formats: plain channels-last (nxc) is
        // used only when at least one of src/dst already matches it and the
        // other either matches too or is format_kind::any; otherwise the
        // 8-channel blocked layout (nCx8c) is used. Weights are 8i8o blocked
        // (with a leading groups dim when grouped).
        bool set_default_formats() {
            using namespace format_tag;

            const memory_desc_wrapper src_d(&src_md_);
            const memory_desc_wrapper dst_d(&dst_md_);

            const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
            const auto dat_tag_nCx8c
                    = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
            const auto curr_src_tag
                    = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            const auto curr_dst_tag
                    = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
            const auto is_data_layout_nxc
                    = IMPLICATION(curr_src_tag != dat_tag_nxc,
                              src_d.format_kind() == format_kind::any)
                    && IMPLICATION(curr_dst_tag != dat_tag_nxc,
                            dst_d.format_kind() == format_kind::any)
                    && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
            auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
            auto wei_tag = with_groups()
                    ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Deep-copy helper used by the copy ctor; clones the dw conv pd when
        // present.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(other.dw_conv_pd_->clone());
                if (!dw_conv_pd_) return status::out_of_memory;
            }
            return status::success;
        }

        // Sets up the fused depthwise convolution: extracts the dw conv
        // descriptor/attributes from the post-op list, creates and inits the
        // dw pd, reconciles the two kernels' channel blocking, and books the
        // intermediate buffer in the scratchpad.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            // The dw conv consumes this 1x1 conv's dst as its src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = true && (!mayiuse(avx))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;

            // Builds cd_dw/attr_dw for the dw conv from the post-op entry at
            // dw_po_index, with src_md as its input.
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            CHECK(safe_ptr_assign(
                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
            CHECK(dw_conv_pd_->init(engine));
            auto &jcp_dw = dw_conv_pd_->jcp_;

            // Compatibility: the dw conv must consume exactly what the 1x1
            // produces — same src md, no output-channel tail in the 1x1, and
            // (if the dw kernel blocks over ow) a single ow block covering
            // the whole row.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
            if (!ok) return status::unimplemented;

            // By this point the dw pd must have resolved all formats.
            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw.is_fused_conv = true;
            // TODO: Support/experiment arbitrary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
                --jcp_dw.nb_ch_blocking;

            // Channel width of the buffer handed from the 1x1 kernel to the
            // dw kernel per load block.
            jcp_dw.dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;

            const auto dat_tag_nxc = utils::pick(ndims() - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
            const bool is_data_nxc = utils::everyone_is(
                    dat_tag_nxc, jcp_1x1.src_tag, jcp_1x1.dst_tag);
            // For blocked layouts the 1x1 kernel's output step is recomputed
            // for writing into the fusion buffer instead of the real dst.
            if (!is_data_nxc)
                jcp_1x1.bcast_loop_output_step = jcp_1x1.ur * jcp_1x1.load_block
                        * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread staging buffer: kh rows of iw pixels for
            // dw_conv_buffer_oc channels each.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
                    * jcp_dw.dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            jit_uni_dw_conv_fwd_kernel<sse41, data_type::f32>::init_scratchpad(
                    dw_scratchpad, jcp_dw);

            return status::success;
        }
    };

    jit_sse41_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    // Generates the 1x1 JIT kernel, and the dw conv kernel when fused.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_sse41_1x1_conv_kernel_f32(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(
                            pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0))));
            return kernel_dw_->create_kernel();
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Implemented out of line (in the matching .cpp): parallel driver and
    // per-thread body of the forward pass.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr, const data_t *src,
            const data_t *weights, const data_t *bias, const data_t *weights_dw,
            const data_t *bias_dw, data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<jit_sse41_1x1_conv_kernel_f32> kernel_;
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<sse41>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
275
276} // namespace x64
277} // namespace cpu
278} // namespace impl
279} // namespace dnnl
280
281#endif
282