/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #ifndef CPU_X64_JIT_UNI_X8S8S32X_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_UNI_X8S8S32X_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_hashing.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
32 | #include "cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp" |
33 | #include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace cpu { |
38 | namespace x64 { |
39 | |
// JIT forward 1x1 convolution for int8: "x8s8s32x" = u8/s8 source, s8
// weights, s32 accumulator, {f32, s32, s8, u8} destination (see the checks
// in pd_t::init()). Templated on the CPU ISA (the weights-format switch in
// set_or_check_wei_format() handles sse41 and avx2; the implementation name
// reports avx2_vnni when jcp_.has_vnni is set). A depthwise convolution
// post-op may be fused, in which case a second (dw) kernel is created and
// consumes the 1x1 output through a per-thread scratchpad buffer.
template <cpu_isa_t isa>
struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copy construction must deep-copy the fused dw conv pd (if any);
        // copy() clones dw_conv_pd_ and re-points jcp_dw_ into the clone.
        // On failure the pd is marked uninitialized rather than throwing.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_uni_int8_1x1:",
                        isa == avx2 && jcp_.has_vnni ? avx2_vnni : isa, ""),
                jit_uni_x8s8s32x_1x1_convolution_fwd_t);

        // Validates data types / attributes, prepares the "reduce src to
        // unit stride" (rtus) helper state (see jit_uni_1x1_conv_utils.hpp),
        // fills the kernel configuration jcp_, optionally sets up the fused
        // depthwise conv post-op, and books scratchpad memory.
        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            // Supported: fwd direct conv, u8/s8 src, s8 weights, s32
            // accumulator, f32/s32/s8/u8 bias and dst; attributes limited to
            // runtime scales, runtime zero points, post-ops and sum dt.
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    // Scales are accepted only for these arguments (incl.
                    // the fused dw conv's weights/dst).
                    && attr()->scales_.has_default_values(
                            {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST})
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok()
                    // src/dst default to channels-last; weights stay 'any'
                    // here and are fixed up by set_or_check_wei_format().
                    && set_default_formats_common(
                            dat_tag(), format_tag::any, dat_tag())
                    && set_or_check_wei_format()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_conf(jcp_,
                    *conv_d, *src_d, *weights_md(), *dst_md(),
                    with_bias() ? *weights_md(1) : types::zero_md(), attr_,
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            // init_conf() decides whether a dw conv post-op is fused.
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_scratchpad(
                    scratchpad, jcp_, *attr());

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv the primitive's visible dst is the dw conv's
        // dst, not this pd's own dst_md_.
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes queries for the fused dw conv's weights/bias to the nested
        // dw conv pd; everything else falls through to the base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // Declares the extra arguments introduced by the fused dw conv
        // (its weights, optional bias, and output scales) as inputs.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_OUTPUT_SCALES)
                    && jcp_.with_dw_conv)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // 1x1 kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-src-to-unit-stride state
        jit_conv_conf_t *jcp_dw_; // points into dw_conv_pd_; doesn't own a resource
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;
        using dw_pd_t = typename jit_uni_x8s8s32x_convolution_fwd_t<isa>::pd_t;

    protected:
        bool zero_points_ok() const {
            // Only common zero points are supported -> mask should only be 0;
            // weights must have no zero points at all.
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }

        // Shared helper for the copy ctor: copies the POD configs and
        // deep-clones the nested dw conv pd, re-pointing jcp_dw_ at the
        // clone's jcp_ so it never dangles into the source pd.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;

                jcp_dw_ = &(static_cast<dw_pd_t *>(dw_conv_pd_.get())->jcp_);
            }
            return status::success;
        }

        // Channels-last data tag chosen by spatial rank: nwc/nhwc/ndhwc.
        format_tag_t dat_tag() const {
            return utils::pick(ndims() - 3, format_tag::nwc, format_tag::nhwc,
                    format_tag::ndhwc);
        }

        // Picks the ISA-specific blocked weights tag and the int8 "extra"
        // fields: s8 source enables s8s8 compensation plus a 0.5 scale
        // adjustment when avx2_vnni is unavailable; src zero points enable
        // asymmetric-src compensation. If the user left weights format as
        // 'any' the desired md is installed, otherwise it is checked for an
        // exact match.
        bool set_or_check_wei_format() {
            using namespace format_tag;
            using namespace memory_extra_flags;
            const auto zp = attr()->zero_points_;
            const int c_mask = 0x1,
                      g_mask = 0x3; // mask for i/o-channel and ngroups

            const bool is_src_s8 = src_md_.data_type == data_type::s8;
            const bool is_src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
            format_tag_t wei_tag;
            switch (isa) {
                case avx2:
                    wei_tag = with_groups()
                            ? utils::pick(ndims() - 3, gOIw2i8o4i, gOIhw2i8o4i,
                                    gOIdhw2i8o4i)
                            : utils::pick(ndims() - 3, OIw2i8o4i, OIhw2i8o4i,
                                    OIdhw2i8o4i);
                    break;
                case sse41:
                    wei_tag = with_groups() ? utils::pick(ndims() - 3, gOIw4o4i,
                                      gOIhw4o4i, gOIdhw4o4i)
                                            : utils::pick(ndims() - 3, OIw4o4i,
                                                    OIhw4o4i, OIdhw4o4i);
                    break;
                default: assert(!"Current ISA is not supported!"); break;
            }

            memory_desc_t want_wei_md = weights_md_;
            memory_desc_init_by_tag(want_wei_md, wei_tag);
            if (is_src_s8) {
                want_wei_md.extra.flags
                        = 0 | compensation_conv_s8s8 | scale_adjust;
                want_wei_md.extra.compensation_mask
                        = with_groups() ? g_mask : c_mask;
                want_wei_md.extra.scale_adjust
                        = mayiuse(avx2_vnni) ? 1.0f : 0.5f;
            }
            if (is_src_zero_point) {
                want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
                want_wei_md.extra.asymm_compensation_mask
                        = with_groups() ? g_mask : c_mask;
            }

            if (weights_md_.format_kind == format_kind::any) {
                weights_md_ = want_wei_md;
                return true;
            }

            return weights_md_ == want_wei_md;
        }

        // Creates the pd for the depthwise conv post-op and couples its
        // channel blocking with the 1x1 conv's load blocking so the dw conv
        // can consume 1x1 output tiles from a per-thread fusion buffer.
        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            // The dw conv's src is this 1x1 conv's dst.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = true && (!mayiuse(isa == avx2 ? avx512_core : avx2))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            // Split the combined attr into a dw conv desc + dw attr.
            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            std::unique_ptr<dw_pd_t> fusable_pd(
                    new dw_pd_t(&cd_dw, &attr_dw, nullptr));
            CHECK(fusable_pd->init(engine));
            jcp_dw_ = &(fusable_pd->jcp_);
            dw_conv_pd_ = std::move(fusable_pd);

            // The dw conv must accept the 1x1 dst layout as-is, the 1x1 oc
            // must be un-padded in oc_block units, and the dw conv must not
            // split the output width.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(jcp_dw_->ow_block,
                            jcp_dw_->ow_block == jcp_dw_->ow);
            if (!ok) return status::unimplemented;

            assert(jcp_dw_);
            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw_->is_fused_conv = true;
            // TODO: Support/experiment arbitary oc_work in dw conv.
            // Until then we keep ch_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw_->nb_ch_blocking != 0)
                --jcp_dw_->nb_ch_blocking;

            jcp_dw_->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step = jcp_1x1.ur
                    * (jcp_1x1.nb_load_blocking * jcp_1x1.oc_block)
                    * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread intermediate buffer handed from 1x1 to dw conv:
            // kh dw input rows of iw * dw_conv_buffer_oc elements each.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw_->kh
                    * jcp_dw_->iw * jcp_dw_->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(
                    dw_scratchpad, *jcp_dw_, *(dw_conv_pd_->attr()));

            return status::success;
        }
    };

    // init_rtus_driver() needs access to the private rtus_driver_ member.
    template <cpu_isa_t _isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    using fusable_pd_type =
            typename jit_uni_x8s8s32x_convolution_fwd_t<isa>::pd_t;

    jit_uni_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Note: In case of fused depthwise convolution, the final output data type
    // after fusion may not be same as for dst.
    typedef typename prec_traits<data_type::s32>::type acc_data_t;

    // Generates the 1x1 kernel, the fused dw kernel (when configured), and
    // the rtus driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_uni_x8s8s32x_1x1_conv_kernel<isa>(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md())));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_),
                            *(pd()->dw_conv_pd_->attr()),
                            *(pd()->dw_conv_pd_->dst_md()))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<isa>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    // Per-thread body of execute_forward(); defined in the .cpp file.
    void execute_forward_thr(const int ithr, const int nthr, const char *src,
            const char *weights, const char *bias, const char *weights_dw,
            const char *bias_dw, char *dst, const float *oscales,
            const float *dst_scales, const float *dw_oscales,
            const float *dw_dst_scales, const int32_t *src_zero_point,
            const int32_t *dst_zero_point,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_uni_x8s8s32x_1x1_conv_kernel<isa>> kernel_;
    std::unique_ptr<rtus_driver_t<isa>> rtus_driver_;
    using dw_conv_kernel_t = jit_uni_x8s8s32x_fwd_kernel<isa>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
371 | |
372 | } // namespace x64 |
373 | } // namespace cpu |
374 | } // namespace impl |
375 | } // namespace dnnl |
376 | |
377 | #endif |
378 | |