1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <memory> |
18 | |
19 | #include "common/math_utils.hpp" |
20 | |
21 | #include "cpu/platform.hpp" |
22 | #include "cpu/primitive_attr_postops.hpp" |
23 | #include "cpu/ref_io_helper.hpp" |
24 | #include "cpu/simple_q10n.hpp" |
25 | |
26 | #if DNNL_X64 |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_gemm_inner_product_utils.hpp" |
29 | #endif |
30 | |
31 | #include "cpu/gemm_inner_product_utils.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace inner_product_utils { |
37 | |
38 | struct ref_pp_kernel_t : public pp_kernel_t { |
39 | ref_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, |
40 | const primitive_attr_t *attr, data_type_t bias_dt, |
41 | data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum) |
42 | : pp_kernel_t( |
43 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum) |
44 | , ref_post_ops_(this->do_sum_ || this->do_eltwise_ || this->do_binary_ |
45 | ? utils::make_unique<ref_post_ops_t>( |
46 | this->post_ops_, skip_sum) |
47 | : nullptr) {} |
48 | |
49 | void operator()(void *dst, const void *acc, const char *bias, |
50 | const float *scales, float dst_scale, size_t start, |
51 | size_t dst_logical_offs, size_t dim1_off, size_t end, |
52 | size_t runtime_oc, dim_t dst_mb_stride, |
53 | const float *dst_zero_points, |
54 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
55 | size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx, |
56 | const memory_desc_t &dst_md) const override; |
57 | |
58 | private: |
59 | std::unique_ptr<ref_post_ops_t> ref_post_ops_; |
60 | }; |
61 | |
62 | void ref_pp_kernel_t::operator()(void *dst, const void *acc, const char *bias, |
63 | const float *scales, float dst_scale, size_t start, |
64 | size_t dst_logical_off, size_t dim1_off, size_t end, size_t runtime_oc, |
65 | dim_t dst_mb_stride, const float *dst_zero_points, |
66 | const void * /* post_ops_binary_rhs_arg_vec */, |
67 | const void * /* dst_orig */, size_t /* first_mb_matrix_addr_off */, |
68 | const exec_ctx_t &ctx, const memory_desc_t &dst_md) const { |
69 | if (end <= start) return; |
70 | |
71 | const size_t OC = this->runtime_oc() ? runtime_oc : this->OC_; |
72 | |
73 | ref_post_ops_t::args_t args; |
74 | args.ctx = &ctx; |
75 | args.dst_md = &dst_md; |
76 | const bool apply_postops |
77 | = this->do_sum_ || this->do_eltwise_ || this->do_binary_; |
78 | auto calculate_dst_value_and_increment_oc = |
79 | [&](const void *acc, void *dst, size_t off, size_t &oc_value, |
80 | const size_t dst_offset) { |
81 | float d = io::load_float_value(this->acc_data_type_, acc, off); |
82 | if (this->do_scale_) |
83 | d *= scales[oc_value * this->scale_idx_mult_]; |
84 | if (this->do_bias()) { |
85 | const float b = io::load_float_value( |
86 | this->bias_data_type_, bias, oc_value); |
87 | d += b; |
88 | } |
89 | if (apply_postops) { |
90 | if (this->do_sum_) |
91 | args.dst_val = io::load_float_value( |
92 | this->sum_data_type_, dst, off); |
93 | args.l_offset = dst_offset; |
94 | ref_post_ops_->execute(d, args); |
95 | } |
96 | if (this->do_dst_scale_) d *= dst_scale; |
97 | if (this->do_dst_zero_points_) d += dst_zero_points[0]; |
98 | io::store_float_value(this->dst_data_type_, d, dst, off); |
99 | oc_value = (oc_value == OC - 1) ? 0 : oc_value + 1; |
100 | }; |
101 | |
102 | size_t oc = start % OC; |
103 | dim_t src1_bin_po_offt = dst_logical_off; |
104 | if (this->has_trivial_mb_stride()) { |
105 | // keep separate code path to avoid performance degradations |
106 | for (size_t i = start; i < end; i++) { |
107 | calculate_dst_value_and_increment_oc( |
108 | acc, dst, i, oc, src1_bin_po_offt); |
109 | ++src1_bin_po_offt; |
110 | } |
111 | } else { |
112 | const dim_t offt = (start / OC) * dst_mb_stride + oc; |
113 | const bool acc_is_dst = dst == acc; |
114 | dst = static_cast<char *>(dst) + this->dst_data_type_size_ * offt; |
115 | // if dst and acc point to same address (inplace), then strides |
116 | // must be similar, else assume acc buffer is dense. |
117 | acc = static_cast<const char *>(acc) |
118 | + this->acc_data_type_size_ * (acc_is_dst ? offt : start); |
119 | size_t i_elem = 0; |
120 | while (start < end) { |
121 | calculate_dst_value_and_increment_oc( |
122 | acc, dst, i_elem, oc, src1_bin_po_offt); |
123 | if (oc == 0) { |
124 | const auto stride = dst_mb_stride - OC; |
125 | dst = static_cast<char *>(dst) |
126 | + this->dst_data_type_size_ * stride; |
127 | // if dst and acc point to same address (inplace), then strides |
128 | // must be similar, else assume acc buffer is dense. |
129 | if (acc_is_dst) |
130 | acc = static_cast<const char *>(acc) |
131 | + this->acc_data_type_size_ * stride; |
132 | } |
133 | ++src1_bin_po_offt; |
134 | ++start; |
135 | ++i_elem; |
136 | } |
137 | } |
138 | } |
139 | |
140 | // Interface section |
141 | |
142 | pp_kernel_t::pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, |
143 | const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt, |
144 | const memory_desc_t *dst_md, bool skip_sum) |
145 | : OC_(OC) |
146 | , MB_(MB) |
147 | , dst_mb_stride_(dst_mb_stride) |
148 | , bias_data_type_(bias_dt) |
149 | , acc_data_type_(acc_dt) |
150 | , dst_data_type_(dst_md->data_type) |
151 | , ndims_(dst_md->ndims) { |
152 | do_scale_ = !attr->scales_.get(DNNL_ARG_SRC).has_default_values() |
153 | || !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); |
154 | if (do_scale_) { |
155 | int wei_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
156 | // matmul: per_oc: 1 << (ndims_ - 1) |
157 | // ip: per_oc: 1 << 0 |
158 | scale_idx_mult_ = wei_mask == (1 << (ndims_ - 1)) || wei_mask == 1 << 0; |
159 | } |
160 | do_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values(); |
161 | |
162 | post_ops_ = attr->post_ops_; |
163 | const int eltwise_ind = post_ops_.find(primitive_kind::eltwise); |
164 | do_eltwise_ = eltwise_ind != -1; |
165 | |
166 | const int binary_ind = post_ops_.find(primitive_kind::binary); |
167 | do_binary_ = binary_ind != -1; |
168 | |
169 | const int sum_ind = post_ops_.find(primitive_kind::sum); |
170 | do_sum_ = sum_ind != -1 && !skip_sum; |
171 | if (do_sum_) { |
172 | sum_scale_ = post_ops_.entry_[sum_ind].sum.scale; |
173 | sum_zp_ = post_ops_.entry_[sum_ind].sum.zero_point; |
174 | const auto &sum_dt = post_ops_.entry_[sum_ind].sum.dt; |
175 | sum_data_type_ = sum_dt != data_type::undef ? sum_dt : dst_data_type_; |
176 | } |
177 | |
178 | dst_data_type_size_ = types::data_type_size(dst_data_type_); |
179 | if (do_bias()) |
180 | bias_data_type_size_ = types::data_type_size(bias_data_type_); |
181 | |
182 | if (!attr->zero_points_.has_default_values(DNNL_ARG_DST)) |
183 | do_dst_zero_points_ = true; |
184 | } |
185 | |
186 | pp_kernel_t *pp_kernel_t::create(size_t OC, size_t MB, dim_t dst_mb_stride, |
187 | const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt, |
188 | const memory_desc_t *dst_md, bool skip_sum) { |
189 | #if DNNL_X64 |
190 | auto *res = x64::inner_product_utils::jit_pp_kernel_create( |
191 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
192 | if (res) return res; |
193 | #endif |
194 | |
195 | return new ref_pp_kernel_t( |
196 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
197 | } |
198 | |
199 | bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d, |
200 | const bcast_set_t &enabled_bcast_strategy) { |
201 | #if DNNL_X64 |
202 | static constexpr auto isa_supported |
203 | = x64::inner_product_utils::jit_pp_kernel_supported_isa(); |
204 | using namespace cpu::x64; |
205 | if (mayiuse(isa_supported)) { |
206 | using namespace x64::injector; |
207 | static constexpr bool sum_at_pos_0_only = true; |
208 | static constexpr bool sum_requires_scale_one = false; |
209 | static constexpr bool sum_requires_zp_zero = false; |
210 | const auto ndims = dst_d->ndims(); |
211 | |
212 | bool is_binary_po_channel_bcast {}; |
213 | bool is_binary_po_per_mb_w_bcast {}; |
214 | bool is_binary_po_per_w_bcast {}; |
215 | std::tie(is_binary_po_channel_bcast, is_binary_po_per_mb_w_bcast, |
216 | is_binary_po_per_w_bcast) |
217 | = binary_injector_utils::bcast_strategies_present_tup( |
218 | post_ops.entry_, *dst_d, |
219 | broadcasting_strategy_t::per_mb_spatial, |
220 | broadcasting_strategy_t::per_mb_w, |
221 | broadcasting_strategy_t::per_w); |
222 | const bool supported_binary_bcast |
223 | = IMPLICATION(is_binary_po_channel_bcast, |
224 | utils::one_of(ndims, 3, 4)) |
225 | && IMPLICATION( |
226 | is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4)) |
227 | && IMPLICATION( |
228 | is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4)); |
229 | const cpu_isa_t isa = get_max_cpu_isa(); |
230 | return supported_binary_bcast |
231 | && injector::post_ops_ok({isa, {binary, eltwise, sum}, post_ops, |
232 | dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
233 | sum_requires_zp_zero, enabled_bcast_strategy}); |
234 | } |
235 | #endif |
236 | for (size_t i = 0; i < post_ops.entry_.size(); i++) { |
237 | const auto &post_op = post_ops.entry_[i]; |
238 | const bool sum_postop_present = post_op.is_sum(false); |
239 | if (sum_postop_present && i > 0) return false; |
240 | if (!(sum_postop_present || post_op.is_eltwise() |
241 | || post_op.is_binary())) |
242 | return false; |
243 | } |
244 | return true; |
245 | } |
246 | |
247 | bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d, |
248 | const bcast_set_t &enabled_bcast_strategy) { |
249 | const auto dst_md = memory_desc_wrapper(dst_d); |
250 | return post_ops_ok(post_ops, &dst_md, enabled_bcast_strategy); |
251 | } |
252 | |
253 | } // namespace inner_product_utils |
254 | } // namespace cpu |
255 | } // namespace impl |
256 | } // namespace dnnl |
257 | |