/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_X8S8S32X_1X1_CONVOLUTION_HPP
#define CPU_X64_JIT_UNI_X8S8S32X_1X1_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/primitive_hashing.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/dw_convolution_utils.hpp"
#include "cpu/platform.hpp"

#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

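// Int8 (u8/s8 source, s8 weights, s32 accumulation) forward 1x1 convolution
// built on a JIT-generated kernel; this template is instantiated for the
// sse41 and avx2 ISAs (with an avx2_vnni flavor selected at runtime). It can
// fuse a depthwise convolution post-op: the 1x1 output is kept in a
// per-thread scratchpad buffer and fed straight into the dw kernel (see
// depthwise_po_init() below).
//
// A minimal user-side sketch of a convolution that can dispatch to this
// implementation (assumes the v3-style public oneDNN API; `eng` and all
// shapes are placeholders, not taken from this file):
//
//     using tag = dnnl::memory::format_tag;
//     using dt = dnnl::memory::data_type;
//     dnnl::memory::desc src_md({1, 64, 56, 56}, dt::u8, tag::nhwc);
//     dnnl::memory::desc wei_md({128, 64, 1, 1}, dt::s8, tag::any);
//     dnnl::memory::desc dst_md({1, 128, 56, 56}, dt::u8, tag::nhwc);
//     auto pd = dnnl::convolution_forward::primitive_desc(eng,
//             dnnl::prop_kind::forward_inference,
//             dnnl::algorithm::convolution_direct, src_md, wei_md, dst_md,
//             /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});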
template <cpu_isa_t isa>
struct jit_uni_x8s8s32x_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        using dw_conv_pd_type = cpu_convolution_fwd_pd_t;
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_uni_int8_1x1:",
                        isa == avx2 && jcp_.has_vnni ? avx2_vnni : isa, ""),
                jit_uni_x8s8s32x_1x1_convolution_fwd_t);

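        // Verifies that the op descriptor and attributes match what the JIT
        // kernel supports (int8 src/weights, s32 accumulation, runtime
        // scales and zero points, post-ops, NHWC-family activation layouts),
        // then fills the kernel configuration jcp_, sets up the optional
        // fused dw conv, and books the scratchpad.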
        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->scales_.has_default_values(
                            {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                                    DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST})
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok()
                    && set_default_formats_common(
                            dat_tag(), format_tag::any, dat_tag())
                    && set_or_check_wei_format()
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            CHECK(jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_conf(jcp_,
                    *conv_d, *src_d, *weights_md(), *dst_md(),
                    with_bias() ? *weights_md(1) : types::zero_md(), attr_,
                    dnnl_get_max_threads(), rtus_.reduce_src_));
            if (jcp_.with_dw_conv) CHECK(depthwise_po_init(engine));

            auto scratchpad = scratchpad_registry().registrar();
            jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_scratchpad(
                    scratchpad, jcp_, *attr());

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_OUTPUT_SCALES)
                    && jcp_.with_dw_conv)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;
        jit_conv_conf_t *jcp_dw_; // doesn't own a resource
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;
        using dw_pd_t = typename jit_uni_x8s8s32x_convolution_fwd_t<isa>::pd_t;

    protected:
        bool zero_points_ok() const {
            // Only common (per-tensor) zero points are supported, so the
            // mask must be 0.
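            // E.g. a caller requests a per-tensor zero point via
            // dnnl::primitive_attr::set_zero_points_mask(DNNL_ARG_SRC, 0);
            // any nonzero mask (such as per-channel, 1 << 1) is rejected.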
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }

        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;

                jcp_dw_ = &(static_cast<dw_pd_t *>(dw_conv_pd_.get())->jcp_);
            }
            return status::success;
        }

        format_tag_t dat_tag() const {
            return utils::pick(ndims() - 3, format_tag::nwc, format_tag::nhwc,
                    format_tag::ndhwc);
        }

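        // Picks the blocked int8 weight layout the kernel expects. E.g. on
        // avx2, OIhw2i8o4i: the innermost "4i" packs four input channels so
        // that one vpmaddubsw (or vpdpbusd with VNNI) step consumes them
        // together, and "8o" spreads eight output channels across the eight
        // s32 lanes of a ymm register. When src is s8, the descriptor also
        // requests s8s8 compensation and a weight scale adjustment; with a
        // source zero point it requests asymmetric-src compensation.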
        bool set_or_check_wei_format() {
            using namespace format_tag;
            using namespace memory_extra_flags;
            const auto zp = attr()->zero_points_;
            const int c_mask = 0x1,
                      g_mask = 0x3; // mask for i/o-channel and ngroups

            const bool is_src_s8 = src_md_.data_type == data_type::s8;
            const bool is_src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
            format_tag_t wei_tag;
            switch (isa) {
                case avx2:
                    wei_tag = with_groups()
                            ? utils::pick(ndims() - 3, gOIw2i8o4i, gOIhw2i8o4i,
                                    gOIdhw2i8o4i)
                            : utils::pick(ndims() - 3, OIw2i8o4i, OIhw2i8o4i,
                                    OIdhw2i8o4i);
                    break;
                case sse41:
                    wei_tag = with_groups()
                            ? utils::pick(ndims() - 3, gOIw4o4i, gOIhw4o4i,
                                    gOIdhw4o4i)
                            : utils::pick(ndims() - 3, OIw4o4i, OIhw4o4i,
                                    OIdhw4o4i);
                    break;
                default: assert(!"Current ISA is not supported!"); break;
            }

            memory_desc_t want_wei_md = weights_md_;
            memory_desc_init_by_tag(want_wei_md, wei_tag);
            if (is_src_s8) {
                want_wei_md.extra.flags
                        = 0 | compensation_conv_s8s8 | scale_adjust;
                want_wei_md.extra.compensation_mask
                        = with_groups() ? g_mask : c_mask;
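                // Without VNNI, u8*s8 products are accumulated via
                // vpmaddubsw, whose intermediate int16 sum can saturate;
                // pre-scaling the weights by 0.5f keeps it in range, and
                // the kernel compensates when quantizing the result.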
                want_wei_md.extra.scale_adjust
                        = mayiuse(avx2_vnni) ? 1.0f : 0.5f;
            }
            if (is_src_zero_point) {
                want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
                want_wei_md.extra.asymm_compensation_mask
                        = with_groups() ? g_mask : c_mask;
            }

            if (weights_md_.format_kind == format_kind::any) {
                weights_md_ = want_wei_md;
                return true;
            }

            return weights_md_ == want_wei_md;
        }

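        // Sets up the fused depthwise conv: derives a dw conv desc whose
        // source is this primitive's dst, creates the dw pd with the same
        // ISA, then aligns the 1x1 load blocking with the dw channel
        // blocking so rows of 1x1 output can be handed over through a
        // per-thread scratchpad buffer.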
        status_t depthwise_po_init(engine_t *engine) {
            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;

            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would check that both the
            // 1x1 conv and the dw conv considered here for fusion are optimal
            // independently. That would require creating a new primitive_desc
            // through the primitive_iterator and checking whether they match.
            // Since these creations and/or checks could be heavy, instead:
            // for 1x1: check that no better ISA is available;
            // for dw: always fuse with the same ISA.
            // Caveat: a better dw conv may exist.

            bool ok = true && (!mayiuse(isa == avx2 ? avx512_core : avx2))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            std::unique_ptr<dw_pd_t> fusable_pd(
                    new dw_pd_t(&cd_dw, &attr_dw, nullptr));
            CHECK(fusable_pd->init(engine));
            jcp_dw_ = &(fusable_pd->jcp_);
            dw_conv_pd_ = std::move(fusable_pd);

            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(jcp_dw_->ow_block,
                            jcp_dw_->ow_block == jcp_dw_->ow);
            if (!ok) return status::unimplemented;

            assert(jcp_dw_);
            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw_->is_fused_conv = true;
            // TODO: Support/experiment with arbitrary oc_work in the dw conv.
            // Until then, keep ch_work perfectly divisible.
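            // E.g. with nb_load = 6 and nb_load_blocking = 4, the loop below
            // lowers the blocking to 3 (6 % 3 == 0); nb_ch_blocking of the
            // dw conv is then shrunk the same way until it divides
            // nb_load_blocking.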
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw_->nb_ch_blocking != 0)
                --jcp_dw_->nb_ch_blocking;

            jcp_dw_->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step = jcp_1x1.ur
                    * (jcp_1x1.nb_load_blocking * jcp_1x1.oc_block)
                    * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

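            // Per-thread buffer holding the jcp_dw_->kh rows of 1x1 output
            // (full dw input width by dw_conv_buffer_oc channels) that the
            // dw kernel consumes to produce one dst row.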
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw_->kh
                    * jcp_dw_->iw * jcp_dw_->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            dw_conv_kernel_t::init_scratchpad(
                    dw_scratchpad, *jcp_dw_, *(dw_conv_pd_->attr()));

            return status::success;
        }
    };

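    // rtus = "reduce to unit stride": for strided 1x1 convolutions, the rtus
    // driver copies the strided source into a dense per-thread buffer so the
    // JIT kernel can always assume unit stride.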
    template <cpu_isa_t _isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    using fusable_pd_type =
            typename jit_uni_x8s8s32x_convolution_fwd_t<isa>::pd_t;

    jit_uni_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Note: In case of a fused depthwise convolution, the final output data
    // type after fusion may differ from this primitive's dst data type.
    typedef typename prec_traits<data_type::s32>::type acc_data_t;

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_uni_x8s8s32x_1x1_conv_kernel<isa>(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md())));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(*(pd()->jcp_dw_),
                            *(pd()->dw_conv_pd_->attr()),
                            *(pd()->dw_conv_pd_->dst_md()))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<isa>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr, const char *src,
            const char *weights, const char *bias, const char *weights_dw,
            const char *bias_dw, char *dst, const float *oscales,
            const float *dst_scales, const float *dw_oscales,
            const float *dw_dst_scales, const int32_t *src_zero_point,
            const int32_t *dst_zero_point,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_uni_x8s8s32x_1x1_conv_kernel<isa>> kernel_;
    std::unique_ptr<rtus_driver_t<isa>> rtus_driver_;
    using dw_conv_kernel_t = jit_uni_x8s8s32x_fwd_kernel<isa>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif