1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_INNER_PRODUCT_UTILS_HPP
18#define CPU_GEMM_INNER_PRODUCT_UTILS_HPP
19
20#include "common/broadcast_strategy.hpp"
21#include "common/c_types_map.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/cpu_inner_product_pd.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace inner_product_utils {
31
32struct pp_kernel_t {
33 static pp_kernel_t *create(size_t OC, size_t MB, dim_t dst_mb_stride,
34 const primitive_attr_t *attr, data_type_t bias_dt,
35 data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum);
36 static pp_kernel_t *create(
37 const cpu_inner_product_fwd_pd_t *pd, bool skip_sum) {
38 return create(pd->OC(), pd->MB(), pd->OC(), pd->attr(),
39 pd->desc()->bias_desc.data_type, pd->desc()->accum_data_type,
40 pd->dst_md(), skip_sum);
41 }
42
43 virtual ~pp_kernel_t() = default;
44
45 // mb kernel only supports single-threaded execution where performance
46 // degradation is larger
47 bool sequential_kernel() const { return mb_blk_kernel_; }
48
49 virtual void operator()(void *dst, const void *acc, const char *bias,
50 const float *scales, float dst_scale, size_t start,
51 size_t dst_logical_off, size_t dim1_off, size_t end,
52 size_t runtime_oc, dim_t dst_mb_stride,
53 const float *dst_zero_points,
54 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
55 size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx,
56 const memory_desc_t &dst_md) const = 0;
57
58 virtual status_t create_kernel() { return status::success; }
59
60protected:
61 pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
62 const primitive_attr_t *attr, data_type_t bias_dt,
63 data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum);
64
65 size_t OC_;
66 size_t MB_;
67 dim_t dst_mb_stride_;
68 data_type_t bias_data_type_;
69 data_type_t acc_data_type_;
70 data_type_t dst_data_type_;
71 size_t bias_data_type_size_ = 0;
72 size_t acc_data_type_size_ = 4;
73 size_t dst_data_type_size_ = 0;
74 bool do_scale_ = false;
75 size_t scale_idx_mult_ = 0;
76 bool do_eltwise_ = false;
77 bool do_binary_ = false;
78 bool do_sum_ = false;
79 bool do_dst_scale_ = false;
80 bool do_dst_zero_points_ = false;
81 float sum_scale_ = 0.f;
82 int32_t sum_zp_ = 0;
83 data_type_t sum_data_type_;
84 bool mb_blk_kernel_ = false;
85 post_ops_t post_ops_;
86 int ndims_;
87
88 bool has_trivial_mb_stride() const {
89 return (!runtime_oc()) && (OC_ == (size_t)dst_mb_stride_);
90 }
91 bool do_bias() const { return bias_data_type_ != data_type::undef; }
92 bool runtime_oc() const { return OC_ == (size_t)DNNL_RUNTIME_DIM_VAL; }
93 bool runtime_mb() const { return MB_ == (size_t)DNNL_RUNTIME_DIM_VAL; }
94};
95
96inline const bcast_set_t &gemm_default_strategies() {
97 static const bcast_set_t s
98 = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
99 broadcasting_strategy_t::per_oc_spatial,
100 broadcasting_strategy_t::no_broadcast};
101 return s;
102}
103
104bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
105 const bcast_set_t &enabled_bcast_strategy = gemm_default_strategies());
106bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d,
107 const bcast_set_t &enabled_bcast_strategy = gemm_default_strategies());
108
109} // namespace inner_product_utils
110} // namespace cpu
111} // namespace impl
112} // namespace dnnl
113
114#endif
115