/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP
#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/primitive_hashing.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/dw_convolution_utils.hpp"
#include "cpu/platform.hpp"

#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
#include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_int8_1x1:",
                        ((jcp_.has_vnni) ? avx512_core_vnni : avx512_core), ""),
                jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(), weights_md(1)->data_type == f32)
                    && utils::one_of(
                            dst_md(0)->data_type, f32, s32, s8, u8, bf16)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->scales_.has_default_values(
                            {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST})
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok()
                    && set_default_formats_common(
                            dat_tag(), format_tag::any, dat_tag())
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

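            // rtus ("reduce to unit stride"): a strided 1x1 convolution is
            // turned into a unit-stride one by first gathering the strided
            // source rows into a dense scratchpad that the kernel reads.
            // rtus_prepare() decides whether this copy (reduce_src_) is
            // needed and adjusts the problem descriptors accordingly.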
            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, src_d, weights_md_, dst_md_, bias_md_, *attr(),
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_, *attr());

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_OUTPUT_SCALES)
                    && jcp_.with_dw_conv)
                return arg_usage_t::input;
            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;
        jit_conv_conf_t *jcp_dw_; // not owned; points into dw_conv_pd_'s jcp_
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;
        using dw_pd_t =
                typename jit_avx512_core_x8s8s32x_convolution_fwd_t::pd_t;

    protected:
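        // Activations are kept in a plain channels-last layout; the tag is
        // picked from nwc/nhwc/ndhwc based on the spatial dimensionality.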
        format_tag_t dat_tag() const {
            return utils::pick(src_md_.ndims - 3, format_tag::nwc,
                    format_tag::nhwc, format_tag::ndhwc);
        }

        bool zero_points_ok() const {
            // Only common zero points are supported -> the mask must be 0.
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }

        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;

            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;

                jcp_dw_ = &(static_cast<dw_pd_t *>(dw_conv_pd_.get())->jcp_);
            }
            return status::success;
        }

        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: a robust fusion implementation would check that both the
            // 1x1 conv and the dw conv considered here for fusion are also
            // optimal independently. That would require creating new
            // primitive_descs through the primitive iterator and checking
            // that they match. Since those creations and checks could be
            // expensive, instead:
            // for 1x1: check that no better ISA is available;
            // for dw: always fuse with the same ISA.
            // Caveat: a better dw conv may exist.

            bool ok = !mayiuse(avx512_core_amx)
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: the threshold below may be further tuned.
                    && (l2_cache < src_d.size())
                    // The load_grp_count check can be redundant due to the l2
                    // check above. It is kept explicit because the current
                    // driver doesn't work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            std::unique_ptr<dw_pd_t> fusable_pd(
                    new dw_pd_t(&cd_dw, &attr_dw, nullptr));
            CHECK(fusable_pd->init(engine));
            jcp_dw_ = &(fusable_pd->jcp_);
            dw_conv_pd_ = std::move(fusable_pd);

            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(jcp_dw_->ow_block,
                            jcp_dw_->ow_block == jcp_dw_->ow);
            if (!ok) return status::unimplemented;

            assert(jcp_dw_);
            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw_->is_fused_conv = true;
            // TODO: Support/experiment with arbitrary oc_work in the dw conv.
            // Until then we keep ch_work perfectly divisible.
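            // e.g., with nb_load = 6 and nb_load_blocking = 4, the loop below
            // settles on nb_load_blocking = 3, since 6 % 3 == 0.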
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw_->nb_ch_blocking != 0)
                --jcp_dw_->nb_ch_blocking;

            jcp_dw_->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step = jcp_1x1.ur
                    * (jcp_1x1.nb_load_blocking * jcp_1x1.oc_block)
                    * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

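            // Per-thread intermediate buffer that carries the 1x1 conv
            // output into the fused dw conv: kh rows of width iw, each
            // holding dw_conv_buffer_oc channels.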
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw_->kh
                    * jcp_dw_->iw * jcp_dw_->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(
                    dw_scratchpad, *jcp_dw_, *(dw_conv_pd_->attr()));
            return status::success;
        }
    };
    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Note: in case of a fused depthwise convolution, the final output data
    // type after fusion may not be the same as dst's data type.
    typedef typename prec_traits<data_type::s32>::type acc_data_t;

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_x8s8s32x_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_),
                            *(pd()->dw_conv_pd_->attr()), *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

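        // Set up the rtus driver, which performs the strided-source copy
        // into the scratchpad when the source has to be reduced to unit
        // stride.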
        CHECK(init_rtus_driver<avx512_core>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr, const char *src,
            const char *weights, const char *bias, const char *weights_dw,
            const char *bias_dw, char *dst, const float *oscales,
            const float *dst_scales, const float *dw_oscales,
            const float *dw_dst_scales, const int32_t *src_zero_point,
            const int32_t *dst_zero_point,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<jit_avx512_core_x8s8s32x_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_core>> rtus_driver_;
    using dw_conv_kernel_t = jit_avx512_core_x8s8s32x_fwd_kernel;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif