1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <memory>
18
19#include "common/math_utils.hpp"
20
21#include "cpu/platform.hpp"
22#include "cpu/primitive_attr_postops.hpp"
23#include "cpu/ref_io_helper.hpp"
24#include "cpu/simple_q10n.hpp"
25
26#if DNNL_X64
27#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28#include "cpu/x64/jit_gemm_inner_product_utils.hpp"
29#endif
30
31#include "cpu/gemm_inner_product_utils.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace inner_product_utils {
37
38struct ref_pp_kernel_t : public pp_kernel_t {
39 ref_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
40 const primitive_attr_t *attr, data_type_t bias_dt,
41 data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum)
42 : pp_kernel_t(
43 OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum)
44 , ref_post_ops_(this->do_sum_ || this->do_eltwise_ || this->do_binary_
45 ? utils::make_unique<ref_post_ops_t>(
46 this->post_ops_, skip_sum)
47 : nullptr) {}
48
49 void operator()(void *dst, const void *acc, const char *bias,
50 const float *scales, float dst_scale, size_t start,
51 size_t dst_logical_offs, size_t dim1_off, size_t end,
52 size_t runtime_oc, dim_t dst_mb_stride,
53 const float *dst_zero_points,
54 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
55 size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx,
56 const memory_desc_t &dst_md) const override;
57
58private:
59 std::unique_ptr<ref_post_ops_t> ref_post_ops_;
60};
61
// Reference implementation of the post-processing pass. Walks the linear
// accumulator range [start, end): each element is loaded as f32, multiplied
// by the (common or per-oc) scale, biased, run through the post-ops chain,
// scaled by dst_scale, shifted by the dst zero-point, and stored into `dst`
// in the destination data type. `dim1_off`, the binary rhs arg vector,
// `dst_orig` and `first_mb_matrix_addr_off` are accepted for interface
// compatibility and are unused by this reference path.
void ref_pp_kernel_t::operator()(void *dst, const void *acc, const char *bias,
        const float *scales, float dst_scale, size_t start,
        size_t dst_logical_off, size_t dim1_off, size_t end, size_t runtime_oc,
        dim_t dst_mb_stride, const float *dst_zero_points,
        const void * /* post_ops_binary_rhs_arg_vec */,
        const void * /* dst_orig */, size_t /* first_mb_matrix_addr_off */,
        const exec_ctx_t &ctx, const memory_desc_t &dst_md) const {
    if (end <= start) return;

    // OC may only be known at execution time (runtime dimensions).
    const size_t OC = this->runtime_oc() ? runtime_oc : this->OC_;

    ref_post_ops_t::args_t args;
    args.ctx = &ctx;
    args.dst_md = &dst_md;
    const bool apply_postops
            = this->do_sum_ || this->do_eltwise_ || this->do_binary_;
    // Processes one element at buffer offset `off` and advances the output
    // channel counter `oc_value`, wrapping it back to 0 after OC elements.
    // `dst_offset` is the logical (dense) destination offset used by binary
    // post-ops to address the matching rhs element.
    auto calculate_dst_value_and_increment_oc =
            [&](const void *acc, void *dst, size_t off, size_t &oc_value,
                    const size_t dst_offset) {
                float d = io::load_float_value(this->acc_data_type_, acc, off);
                if (this->do_scale_)
                    // scale_idx_mult_ is 0 (common scale) or 1 (per-oc).
                    d *= scales[oc_value * this->scale_idx_mult_];
                if (this->do_bias()) {
                    const float b = io::load_float_value(
                            this->bias_data_type_, bias, oc_value);
                    d += b;
                }
                if (apply_postops) {
                    // Sum post-op consumes the current dst value.
                    if (this->do_sum_)
                        args.dst_val = io::load_float_value(
                                this->sum_data_type_, dst, off);
                    args.l_offset = dst_offset;
                    ref_post_ops_->execute(d, args);
                }
                if (this->do_dst_scale_) d *= dst_scale;
                if (this->do_dst_zero_points_) d += dst_zero_points[0];
                io::store_float_value(this->dst_data_type_, d, dst, off);
                oc_value = (oc_value == OC - 1) ? 0 : oc_value + 1;
            };

    size_t oc = start % OC;
    dim_t src1_bin_po_offt = dst_logical_off;
    if (this->has_trivial_mb_stride()) {
        // keep separate code path to avoid performance degradations
        for (size_t i = start; i < end; i++) {
            calculate_dst_value_and_increment_oc(
                    acc, dst, i, oc, src1_bin_po_offt);
            ++src1_bin_po_offt;
        }
    } else {
        // Non-trivial mb stride: rebase `dst` (and possibly `acc`) onto the
        // element this thread starts at, then bump the pointers by the row
        // padding (dst_mb_stride - OC) every time oc wraps to 0.
        const dim_t offt = (start / OC) * dst_mb_stride + oc;
        const bool acc_is_dst = dst == acc;
        dst = static_cast<char *>(dst) + this->dst_data_type_size_ * offt;
        // if dst and acc point to same address (inplace), then strides
        // must be similar, else assume acc buffer is dense.
        acc = static_cast<const char *>(acc)
                + this->acc_data_type_size_ * (acc_is_dst ? offt : start);
        size_t i_elem = 0;
        while (start < end) {
            calculate_dst_value_and_increment_oc(
                    acc, dst, i_elem, oc, src1_bin_po_offt);
            // oc == 0 means a row boundary was just crossed.
            if (oc == 0) {
                const auto stride = dst_mb_stride - OC;
                dst = static_cast<char *>(dst)
                        + this->dst_data_type_size_ * stride;
                // if dst and acc point to same address (inplace), then strides
                // must be similar, else assume acc buffer is dense.
                if (acc_is_dst)
                    acc = static_cast<const char *>(acc)
                            + this->acc_data_type_size_ * stride;
            }
            ++src1_bin_po_offt;
            ++start;
            ++i_elem;
        }
    }
}
139
140// Interface section
141
142pp_kernel_t::pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
143 const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt,
144 const memory_desc_t *dst_md, bool skip_sum)
145 : OC_(OC)
146 , MB_(MB)
147 , dst_mb_stride_(dst_mb_stride)
148 , bias_data_type_(bias_dt)
149 , acc_data_type_(acc_dt)
150 , dst_data_type_(dst_md->data_type)
151 , ndims_(dst_md->ndims) {
152 do_scale_ = !attr->scales_.get(DNNL_ARG_SRC).has_default_values()
153 || !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values();
154 if (do_scale_) {
155 int wei_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_;
156 // matmul: per_oc: 1 << (ndims_ - 1)
157 // ip: per_oc: 1 << 0
158 scale_idx_mult_ = wei_mask == (1 << (ndims_ - 1)) || wei_mask == 1 << 0;
159 }
160 do_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values();
161
162 post_ops_ = attr->post_ops_;
163 const int eltwise_ind = post_ops_.find(primitive_kind::eltwise);
164 do_eltwise_ = eltwise_ind != -1;
165
166 const int binary_ind = post_ops_.find(primitive_kind::binary);
167 do_binary_ = binary_ind != -1;
168
169 const int sum_ind = post_ops_.find(primitive_kind::sum);
170 do_sum_ = sum_ind != -1 && !skip_sum;
171 if (do_sum_) {
172 sum_scale_ = post_ops_.entry_[sum_ind].sum.scale;
173 sum_zp_ = post_ops_.entry_[sum_ind].sum.zero_point;
174 const auto &sum_dt = post_ops_.entry_[sum_ind].sum.dt;
175 sum_data_type_ = sum_dt != data_type::undef ? sum_dt : dst_data_type_;
176 }
177
178 dst_data_type_size_ = types::data_type_size(dst_data_type_);
179 if (do_bias())
180 bias_data_type_size_ = types::data_type_size(bias_data_type_);
181
182 if (!attr->zero_points_.has_default_values(DNNL_ARG_DST))
183 do_dst_zero_points_ = true;
184}
185
186pp_kernel_t *pp_kernel_t::create(size_t OC, size_t MB, dim_t dst_mb_stride,
187 const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt,
188 const memory_desc_t *dst_md, bool skip_sum) {
189#if DNNL_X64
190 auto *res = x64::inner_product_utils::jit_pp_kernel_create(
191 OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
192 if (res) return res;
193#endif
194
195 return new ref_pp_kernel_t(
196 OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
197}
198
199bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
200 const bcast_set_t &enabled_bcast_strategy) {
201#if DNNL_X64
202 static constexpr auto isa_supported
203 = x64::inner_product_utils::jit_pp_kernel_supported_isa();
204 using namespace cpu::x64;
205 if (mayiuse(isa_supported)) {
206 using namespace x64::injector;
207 static constexpr bool sum_at_pos_0_only = true;
208 static constexpr bool sum_requires_scale_one = false;
209 static constexpr bool sum_requires_zp_zero = false;
210 const auto ndims = dst_d->ndims();
211
212 bool is_binary_po_channel_bcast {};
213 bool is_binary_po_per_mb_w_bcast {};
214 bool is_binary_po_per_w_bcast {};
215 std::tie(is_binary_po_channel_bcast, is_binary_po_per_mb_w_bcast,
216 is_binary_po_per_w_bcast)
217 = binary_injector_utils::bcast_strategies_present_tup(
218 post_ops.entry_, *dst_d,
219 broadcasting_strategy_t::per_mb_spatial,
220 broadcasting_strategy_t::per_mb_w,
221 broadcasting_strategy_t::per_w);
222 const bool supported_binary_bcast
223 = IMPLICATION(is_binary_po_channel_bcast,
224 utils::one_of(ndims, 3, 4))
225 && IMPLICATION(
226 is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
227 && IMPLICATION(
228 is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4));
229 const cpu_isa_t isa = get_max_cpu_isa();
230 return supported_binary_bcast
231 && injector::post_ops_ok({isa, {binary, eltwise, sum}, post_ops,
232 dst_d, sum_at_pos_0_only, sum_requires_scale_one,
233 sum_requires_zp_zero, enabled_bcast_strategy});
234 }
235#endif
236 for (size_t i = 0; i < post_ops.entry_.size(); i++) {
237 const auto &post_op = post_ops.entry_[i];
238 const bool sum_postop_present = post_op.is_sum(false);
239 if (sum_postop_present && i > 0) return false;
240 if (!(sum_postop_present || post_op.is_eltwise()
241 || post_op.is_binary()))
242 return false;
243 }
244 return true;
245}
246
247bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d,
248 const bcast_set_t &enabled_bcast_strategy) {
249 const auto dst_md = memory_desc_wrapper(dst_d);
250 return post_ops_ok(post_ops, &dst_md, enabled_bcast_strategy);
251}
252
253} // namespace inner_product_utils
254} // namespace cpu
255} // namespace impl
256} // namespace dnnl
257