/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_X8S8S32X_CONVOLUTION_HPP
#define CPU_X64_JIT_UNI_X8S8S32X_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"

#include "cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

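// Forward convolution on quantized tensors: u8/s8 source, s8 weights, and
// f32/s32/s8/u8 destination with s32 accumulation (pd_t::init() below lists
// the exact supported combinations). The compute kernel is JIT-generated per
// ISA by jit_uni_x8s8s32x_fwd_kernel<isa>.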
template <cpu_isa_t isa>
struct jit_uni_x8s8s32x_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_uni_int8:",
                        isa == avx2 && jcp_.has_vnni ? avx2_vnni : isa, ""),
                jit_uni_x8s8s32x_convolution_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            const bool args_ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok();
            if (!args_ok) return status::unimplemented;

            CHECK(jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jcp_, *desc(),
                    src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads()));

            auto scratchpad = scratchpad_registry().registrar();
            jit_uni_x8s8s32x_fwd_kernel<isa>::init_scratchpad(
                    scratchpad, jcp_, *attr());

            return attr_.set_default_formats(dst_md(0));
        }

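        // Kernel parameters filled in by
        // jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf() during init().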
        jit_conv_conf_t jcp_;

    protected:
        bool zero_points_ok() const {
            // Only common zero points are supported -> mask should only be 0
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }
    };

    jit_uni_x8s8s32x_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_uni_x8s8s32x_fwd_kernel<isa>(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md())));
        return kernel_->create_kernel();
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        const auto &_pd = pd();
        const int ndims = _pd->ndims();
        const bool is_dw = _pd->jcp_.is_depthwise;

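        // ndims is the full tensor rank (minibatch + channels + spatial), so
        // rank 3 is a 1D convolution, rank 4 is 2D (with a dedicated
        // depthwise path), and rank 5 is 3D.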
        switch (ndims) {
            case 3: return execute_forward_1d(ctx);
            case 4:
                if (is_dw) return execute_forward_2d_dw(ctx);
                return execute_forward_2d(ctx);
            case 5: return execute_forward_3d(ctx);
        }
        return status::unimplemented;
    }

private:
    status_t execute_forward_1d(const exec_ctx_t &ctx) const;
    status_t execute_forward_2d(const exec_ctx_t &ctx) const;
    status_t execute_forward_3d(const exec_ctx_t &ctx) const;
    status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
    const pd_t *pd() const {
        return static_cast<const pd_t *>(primitive_t::pd().get());
    }
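    // Folds the runtime source and weight scales into a single output-scale
    // buffer held in the scratchpad, which the JIT kernel then consumes.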
    const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad,
            const float *src_scales, const float *wei_scales) const;

    std::unique_ptr<jit_uni_x8s8s32x_fwd_kernel<isa>> kernel_;
};
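
// Usage sketch (illustrative only, not part of this header): an int8
// convolution created through the public C++ API can dispatch to this
// implementation when pd_t::init() accepts the descriptor on the current
// ISA. The shapes, strides, and `*_mem` memory objects below are placeholder
// assumptions:
//
//     using namespace dnnl;
//     engine eng(engine::kind::cpu, 0);
//     stream strm(eng);
//     memory::desc src_md({1, 32, 14, 14}, memory::data_type::u8,
//             memory::format_tag::any);
//     memory::desc wei_md({64, 32, 3, 3}, memory::data_type::s8,
//             memory::format_tag::any);
//     memory::desc dst_md({1, 64, 14, 14}, memory::data_type::s8,
//             memory::format_tag::any);
//     auto conv_pd = convolution_forward::primitive_desc(eng,
//             prop_kind::forward_inference, algorithm::convolution_direct,
//             src_md, wei_md, dst_md, /*strides=*/{1, 1},
//             /*padding_l=*/{1, 1}, /*padding_r=*/{1, 1});
//     convolution_forward(conv_pd).execute(strm,
//             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
//                     {DNNL_ARG_DST, dst_mem}});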

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif