/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP
#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"

#include "cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

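// Forward convolution primitive for int8 sources (u8/s8), s8 weights, and
// s32 accumulation, JIT-compiled for AVX-512 Core; VNNI instructions are
// used when the CPU supports them (see jcp_.has_vnni below).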
struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_int8:",
                        (jcp_.has_vnni ? avx512_core_vnni : avx512_core), ""),
                jit_avx512_core_x8s8s32x_convolution_fwd_t);

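        // Validates that the descriptor, data types, and attributes (scales,
        // zero points, post-ops) match what the JIT kernel supports, then
        // fills jcp_ via init_conf() and books the required scratchpad.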
        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && utils::one_of(
                            dst_md(0)->data_type, f32, s32, s8, u8, bf16)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::zero_points_runtime
                                    | smask_t::post_ops | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type)
                    && !has_zero_dim_memory() && zero_points_ok();
            if (!ok) return status::unimplemented;

            CHECK(jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jcp_, *desc(),
                    src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads()));

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
                    scratchpad, jcp_, *attr());

            return status::success;
        }

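        // Kernel configuration computed by init_conf() and handed to the
        // JIT kernel when the primitive is created.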
        jit_conv_conf_t jcp_;

    protected:
        bool zero_points_ok() const {
            // Only common (per-tensor) zero points are supported, so the
            // mask must be 0
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }
    };

    jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

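    // Instantiates and JIT-compiles the kernel from the configuration
    // prepared in pd_t::init().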
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_x8s8s32x_fwd_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        return kernel_->create_kernel();
    }

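    // Dispatches on the number of dimensions: ndims 3/4/5 correspond to
    // 1D/2D/3D convolution, with a dedicated depthwise path for 2D.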
    status_t execute(const exec_ctx_t &ctx) const override {
        const auto &_pd = pd();
        if (_pd->ndims() == 3)
            return execute_forward_1d(ctx);
        else if (_pd->ndims() == 4) {
            if (_pd->jcp_.is_depthwise)
                return execute_forward_2d_dw(ctx);
            else
                return execute_forward_2d(ctx);
        } else if (_pd->ndims() == 5)
            return execute_forward_3d(ctx);
        return status::unimplemented;
    }

private:
    status_t execute_forward_1d(const exec_ctx_t &ctx) const;
    status_t execute_forward_2d(const exec_ctx_t &ctx) const;
    status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
    status_t execute_forward_3d(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
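    // Combines the runtime source and weight scales into the effective
    // output scales, staged in a scratchpad buffer.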
    const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad,
            const float *src_scales, const float *wei_scales) const;

    std::unique_ptr<jit_avx512_core_x8s8s32x_fwd_kernel> kernel_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s