1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/primitive_hashing.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_convolution_pd.hpp" |
28 | #include "cpu/dw_convolution_utils.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" |
32 | #include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp" |
33 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace cpu { |
38 | namespace x64 { |
39 | |
// JIT forward 1x1 convolution for int8 (u8/s8 src, s8 weights) on
// AVX512-core CPUs, accumulating in s32 (VNNI instructions used when
// available, see the impl name in DECLARE_COMMON_PD_T). Supports an optional
// fused depthwise-convolution post-op: the 1x1 output is kept in a
// per-thread scratchpad buffer and fed directly into a companion depthwise
// kernel instead of a full round-trip through memory (see
// pd_t::depthwise_po_init()).
struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
    // Primitive descriptor: validates the convolution descriptor/attributes
    // and precomputes the JIT kernel configuration (jcp_), the
    // reduce-to-unit-stride info (rtus_), and — when a depthwise conv
    // post-op is requested — a nested pd for the fused dw conv (dw_conv_pd_).
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copy ctor must deep-copy the nested dw conv pd (if any) and
        // re-point jcp_dw_ into the clone; on failure the pd is flagged
        // uninitialized so the factory rejects it.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_int8_1x1:" ,
                        ((jcp_.has_vnni) ? avx512_core_vnni : avx512_core), "" ),
                jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t);

        // Validates problem/attribute support, then builds the kernel
        // configuration and books scratchpad memory.
        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            // Supported combination: forward direct conv, u8/s8 src,
            // s8 weights, optional f32 bias, f32/s32/s8/u8/bf16 dst,
            // s32 accumulation. Attributes may carry runtime scales,
            // runtime zero points, post-ops, and a sum data type.
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(), weights_md(1)->data_type == f32)
                    && utils::one_of(
                            dst_md(0)->data_type, f32, s32, s8, u8, bf16)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->scales_.has_default_values(
                            {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST})
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok()
                    && set_default_formats_common(
                            dat_tag(), format_tag::any, dat_tag())
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            // Prepare reduce-to-unit-stride: for strided 1x1 convs this
            // presumably sets up rtus_ so the kernel can run as if
            // unit-stride (helper lives in jit_uni_1x1_conv_utils.hpp —
            // exact mechanics not visible here).
            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, src_d, weights_md_, dst_md_, bias_md_, *attr(),
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            // init_conf decides whether the dw conv post-op can be fused;
            // if so, build the nested dw pd and co-tune the blocking.
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_, *attr());

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv, the primitive's visible dst is the dw
        // conv's dst (the 1x1 result only lives in the fusion buffer).
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes the dw post-op's weights/bias argument ids to the nested
        // dw pd's memory descriptors; everything else falls back to the
        // base class.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_usage,
        }

        // Declares the extra execution-time inputs contributed by the
        // fused dw conv (its weights, optional bias, and output scales).
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_OUTPUT_SCALES)
                    && jcp_.with_dw_conv)
                return arg_usage_t::input;
            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // 1x1 kernel configuration
        reduce_to_unit_stride_t rtus_; // strided-src reduction info
        jit_conv_conf_t *jcp_dw_; // doesn't own a resource
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;
        using dw_pd_t =
                typename jit_avx512_core_x8s8s32x_convolution_fwd_t::pd_t;

    protected:
        // Channels-last plain layout matching the spatial rank (1D/2D/3D).
        format_tag_t dat_tag() const {
            return utils::pick(src_md_.ndims - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
        }

        bool zero_points_ok() const {
            // Only common zero points are supported -> mask should only be 0
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }

        // Deep-copy helper used by the copy ctor: clones the nested dw
        // conv pd (if present) and re-points jcp_dw_ at the clone's jcp_,
        // since jcp_dw_ is non-owning.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;

            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;

                jcp_dw_ = &(static_cast<dw_pd_t *>(dw_conv_pd_.get())->jcp_);
            }
            return status::success;
        }

        // Builds the pd for the fused depthwise conv post-op, checks the
        // fusion is profitable/compatible, and co-tunes the 1x1 load
        // blocking so each 1x1 oc chunk maps to whole dw channel blocks.
        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            // The 1x1 dst is the dw conv's src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = !mayiuse(avx512_core_amx)
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            // Derive a standalone conv desc + attr for the dw post-op
            // (helper in dw_convolution_utils.hpp).
            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            std::unique_ptr<dw_pd_t> fusable_pd(
                    new dw_pd_t(&cd_dw, &attr_dw, nullptr));
            CHECK(fusable_pd->init(engine));
            jcp_dw_ = &(fusable_pd->jcp_);
            dw_conv_pd_ = std::move(fusable_pd);

            // Fusion constraints: matching memory between the two convs,
            // 1x1 oc fully covered by oc blocks, and the dw driver handling
            // a single ow block (if it blocks ow at all).
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(jcp_dw_->ow_block,
                            jcp_dw_->ow_block == jcp_dw_->ow);
            if (!ok) return status::unimplemented;

            assert(jcp_dw_);
            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw_->is_fused_conv = true;
            // TODO: Support/experiment arbitary oc_work in dw conv.
            // Until then we keep ch_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            // dw channel blocking must also divide the 1x1 load blocking.
            while (jcp_1x1.nb_load_blocking % jcp_dw_->nb_ch_blocking != 0)
                --jcp_dw_->nb_ch_blocking;

            jcp_dw_->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step = jcp_1x1.ur
                    * (jcp_1x1.nb_load_blocking * jcp_1x1.oc_block)
                    * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread intermediate buffer holding kh dw input rows of
            // width iw for the currently-blocked oc chunk.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw_->kh
                    * jcp_dw_->iw * jcp_dw_->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(
                    dw_scratchpad, *jcp_dw_, *(dw_conv_pd_->attr()));
            return status::success;
        }
    };
    // init_rtus_driver needs access to rtus_ and primitive internals.
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Note: In case of fused depthwise convolution, the final output data type
    // after fusion may not be same as for dst.
    typedef typename prec_traits<data_type::s32>::type acc_data_t;

    // Generates the 1x1 kernel (and the dw kernel when fused), then sets up
    // the reduce-to-unit-stride driver.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_x8s8s32x_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_),
                            *(pd()->dw_conv_pd_->attr()), *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    // Per-thread body of execute_forward (defined in the .cpp); carries the
    // fused dw weights/bias/scales alongside the 1x1 arguments.
    void execute_forward_thr(const int ithr, const int nthr, const char *src,
            const char *weights, const char *bias, const char *weights_dw,
            const char *bias_dw, char *dst, const float *oscales,
            const float *dst_scales, const float *dw_oscales,
            const float *dw_dst_scales, const int32_t *src_zero_point,
            const int32_t *dst_zero_point,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<jit_avx512_core_x8s8s32x_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t = jit_avx512_core_x8s8s32x_fwd_kernel;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};
310 | |
311 | } // namespace x64 |
312 | } // namespace cpu |
313 | } // namespace impl |
314 | } // namespace dnnl |
315 | |
316 | #endif |
317 | |