1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_X8S8S32X_CONVOLUTION_UTILS_HPP
18#define CPU_GEMM_X8S8S32X_CONVOLUTION_UTILS_HPP
19
20#include "cpu/gemm_convolution_utils.hpp"
21#if DNNL_X64
22#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
23#endif
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace gemm_x8s8s32x_convolution_utils {
29
30struct pp_ker_t {
31 static pp_ker_t *create(
32 const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
33 virtual ~pp_ker_t() = default;
34
35 typedef typename prec_traits<data_type::s32>::type acc_data_t;
36
37 virtual void operator()(void *dst, const acc_data_t *acc, const char *bias,
38 const float *scales, float dst_scale, float sum_scale,
39 float signed_scale, int g, size_t start, size_t end,
40 const zero_point_call_params_t &zp,
41 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
42 const exec_ctx_t &ctx, const memory_desc_t &dst_md,
43 const single_gemm_conv_chunk_desc_t &chunk_desc) const = 0;
44
45 virtual status_t create_kernel() { return status::success; }
46
47protected:
48 pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
49
50 const conv_gemm_conf_t &jcp_;
51};
52
53bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);
54bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d);
55bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept;
56
57} // namespace gemm_x8s8s32x_convolution_utils
58} // namespace cpu
59} // namespace impl
60} // namespace dnnl
61
62#endif
63