1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_GEMM_INNER_PRODUCT_UTILS_HPP |
18 | #define CPU_GEMM_INNER_PRODUCT_UTILS_HPP |
19 | |
20 | #include "common/broadcast_strategy.hpp" |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/cpu_inner_product_pd.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace inner_product_utils { |
31 | |
32 | struct pp_kernel_t { |
33 | static pp_kernel_t *create(size_t OC, size_t MB, dim_t dst_mb_stride, |
34 | const primitive_attr_t *attr, data_type_t bias_dt, |
35 | data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum); |
36 | static pp_kernel_t *create( |
37 | const cpu_inner_product_fwd_pd_t *pd, bool skip_sum) { |
38 | return create(pd->OC(), pd->MB(), pd->OC(), pd->attr(), |
39 | pd->desc()->bias_desc.data_type, pd->desc()->accum_data_type, |
40 | pd->dst_md(), skip_sum); |
41 | } |
42 | |
43 | virtual ~pp_kernel_t() = default; |
44 | |
45 | // mb kernel only supports single-threaded execution where performance |
46 | // degradation is larger |
47 | bool sequential_kernel() const { return mb_blk_kernel_; } |
48 | |
49 | virtual void operator()(void *dst, const void *acc, const char *bias, |
50 | const float *scales, float dst_scale, size_t start, |
51 | size_t dst_logical_off, size_t dim1_off, size_t end, |
52 | size_t runtime_oc, dim_t dst_mb_stride, |
53 | const float *dst_zero_points, |
54 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
55 | size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx, |
56 | const memory_desc_t &dst_md) const = 0; |
57 | |
58 | virtual status_t create_kernel() { return status::success; } |
59 | |
60 | protected: |
61 | pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, |
62 | const primitive_attr_t *attr, data_type_t bias_dt, |
63 | data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum); |
64 | |
65 | size_t OC_; |
66 | size_t MB_; |
67 | dim_t dst_mb_stride_; |
68 | data_type_t bias_data_type_; |
69 | data_type_t acc_data_type_; |
70 | data_type_t dst_data_type_; |
71 | size_t bias_data_type_size_ = 0; |
72 | size_t acc_data_type_size_ = 4; |
73 | size_t dst_data_type_size_ = 0; |
74 | bool do_scale_ = false; |
75 | size_t scale_idx_mult_ = 0; |
76 | bool do_eltwise_ = false; |
77 | bool do_binary_ = false; |
78 | bool do_sum_ = false; |
79 | bool do_dst_scale_ = false; |
80 | bool do_dst_zero_points_ = false; |
81 | float sum_scale_ = 0.f; |
82 | int32_t sum_zp_ = 0; |
83 | data_type_t sum_data_type_; |
84 | bool mb_blk_kernel_ = false; |
85 | post_ops_t post_ops_; |
86 | int ndims_; |
87 | |
88 | bool has_trivial_mb_stride() const { |
89 | return (!runtime_oc()) && (OC_ == (size_t)dst_mb_stride_); |
90 | } |
91 | bool do_bias() const { return bias_data_type_ != data_type::undef; } |
92 | bool runtime_oc() const { return OC_ == (size_t)DNNL_RUNTIME_DIM_VAL; } |
93 | bool runtime_mb() const { return MB_ == (size_t)DNNL_RUNTIME_DIM_VAL; } |
94 | }; |
95 | |
96 | inline const bcast_set_t &gemm_default_strategies() { |
97 | static const bcast_set_t s |
98 | = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
99 | broadcasting_strategy_t::per_oc_spatial, |
100 | broadcasting_strategy_t::no_broadcast}; |
101 | return s; |
102 | } |
103 | |
// Post-ops check for a destination described by a memory_desc_wrapper.
// Presumably reports whether `post_ops` can be handled for `dst_d` given the
// broadcasting strategies in `enabled_bcast_strategy` -- the definition is
// not in this header, confirm there. When the last argument is omitted,
// gemm_default_strategies() is used.
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
        const bcast_set_t &enabled_bcast_strategy = gemm_default_strategies());
// Same check, taking the destination as a plain memory_desc_t pointer.
// NOTE(review): likely forwards to the overload above -- verify.
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d,
        const bcast_set_t &enabled_bcast_strategy = gemm_default_strategies());
108 | |
109 | } // namespace inner_product_utils |
110 | } // namespace cpu |
111 | } // namespace impl |
112 | } // namespace dnnl |
113 | |
114 | #endif |
115 | |