1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cmath> |
18 | |
19 | #include "cpu/primitive_attr_postops.hpp" |
20 | #include "cpu/ref_io_helper.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | |
26 | using namespace alg_kind; |
27 | using namespace math; |
28 | |
29 | float compute_binary_scalar(alg_kind_t alg, float x, float y) { |
30 | switch (alg) { |
31 | case binary_add: return x + y; |
32 | case binary_div: return x / y; |
33 | case binary_max: return nstl::max(x, y); |
34 | case binary_min: return nstl::min(x, y); |
35 | case binary_mul: return x * y; |
36 | case binary_sub: return x - y; |
37 | case binary_ge: return x >= y; |
38 | case binary_gt: return x > y; |
39 | case binary_le: return x <= y; |
40 | case binary_lt: return x < y; |
41 | case binary_eq: return x == y; |
42 | case binary_ne: return x != y; |
43 | default: assert(!"not supported operation!" ); return NAN; |
44 | } |
45 | } |
46 | |
47 | float compute_eltwise_scalar_fwd( |
48 | const alg_kind_t alg, float s, float alpha, float beta) { |
49 | float d = 0.f; |
50 | switch (alg) { |
51 | case eltwise_relu: d = relu_fwd(s, alpha); break; |
52 | case eltwise_tanh: d = tanh_fwd(s); break; |
53 | case eltwise_elu: d = elu_fwd(s, alpha); break; |
54 | case eltwise_square: d = square_fwd(s); break; |
55 | case eltwise_abs: d = abs_fwd(s); break; |
56 | case eltwise_sqrt: d = sqrt_fwd(s); break; |
57 | case eltwise_linear: d = linear_fwd(s, alpha, beta); break; |
58 | case eltwise_soft_relu: d = soft_relu_fwd(s, alpha); break; |
59 | case eltwise_logistic: d = logistic_fwd(s); break; |
60 | case eltwise_exp: d = exp_fwd(s); break; |
61 | case eltwise_gelu_tanh: d = gelu_tanh_fwd(s); break; |
62 | case eltwise_swish: d = swish_fwd(s, alpha); break; |
63 | case eltwise_log: d = log_fwd(s); break; |
64 | case eltwise_clip: d = clip_fwd(s, alpha, beta); break; |
65 | case eltwise_clip_v2: d = clip_v2_fwd(s, alpha, beta); break; |
66 | case eltwise_pow: d = pow_fwd(s, alpha, beta); break; |
67 | case eltwise_gelu_erf: d = gelu_erf_fwd(s); break; |
68 | case eltwise_round: d = round_fwd(s); break; |
69 | case eltwise_mish: d = mish_fwd(s); break; |
70 | case eltwise_hardsigmoid: d = hardsigmoid_fwd(s, alpha, beta); break; |
71 | case eltwise_hardswish: d = hardswish_fwd(s, alpha, beta); break; |
72 | case eltwise_relu_use_dst_for_bwd: d = relu_fwd(s, alpha); break; |
73 | case eltwise_tanh_use_dst_for_bwd: d = tanh_fwd(s); break; |
74 | case eltwise_elu_use_dst_for_bwd: d = elu_fwd(s, alpha); break; |
75 | case eltwise_sqrt_use_dst_for_bwd: d = sqrt_fwd(s); break; |
76 | case eltwise_logistic_use_dst_for_bwd: d = logistic_fwd(s); break; |
77 | case eltwise_exp_use_dst_for_bwd: d = exp_fwd(s); break; |
78 | case eltwise_clip_v2_use_dst_for_bwd: |
79 | d = clip_v2_fwd(s, alpha, beta); |
80 | break; |
81 | |
82 | default: assert(!"unknown eltwise alg_kind" ); |
83 | } |
84 | return d; |
85 | } |
86 | |
87 | float compute_eltwise_scalar_bwd( |
88 | const alg_kind_t alg, float dd, float s, float alpha, float beta) { |
89 | float ds = 0.f; |
90 | switch (alg) { |
91 | case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; |
92 | case eltwise_tanh: ds = tanh_bwd(dd, s); break; |
93 | case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; |
94 | case eltwise_square: ds = square_bwd(dd, s); break; |
95 | case eltwise_abs: ds = abs_bwd(dd, s); break; |
96 | case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; |
97 | case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; |
98 | case eltwise_soft_relu: ds = soft_relu_bwd(dd, s, alpha); break; |
99 | case eltwise_logistic: ds = logistic_bwd(dd, s); break; |
100 | case eltwise_exp: ds = exp_bwd(dd, s); break; |
101 | case eltwise_gelu_tanh: ds = gelu_tanh_bwd(dd, s); break; |
102 | case eltwise_swish: ds = swish_bwd(dd, s, alpha); break; |
103 | case eltwise_log: ds = log_bwd(dd, s); break; |
104 | case eltwise_clip: ds = clip_bwd(dd, s, alpha, beta); break; |
105 | case eltwise_clip_v2: ds = clip_v2_bwd(dd, s, alpha, beta); break; |
106 | case eltwise_pow: ds = pow_bwd(dd, s, alpha, beta); break; |
107 | case eltwise_gelu_erf: ds = gelu_erf_bwd(dd, s); break; |
108 | case eltwise_mish: ds = mish_bwd(dd, s); break; |
109 | case eltwise_hardsigmoid: |
110 | ds = hardsigmoid_bwd(dd, s, alpha, beta); |
111 | break; |
112 | case eltwise_hardswish: ds = hardswish_bwd(dd, s, alpha, beta); break; |
113 | case eltwise_relu_use_dst_for_bwd: |
114 | ds = relu_bwd_use_dst(dd, s, alpha); |
115 | break; |
116 | case eltwise_tanh_use_dst_for_bwd: ds = tanh_bwd_use_dst(dd, s); break; |
117 | case eltwise_elu_use_dst_for_bwd: |
118 | ds = elu_bwd_use_dst(dd, s, alpha); |
119 | break; |
120 | case eltwise_sqrt_use_dst_for_bwd: ds = sqrt_bwd_use_dst(dd, s); break; |
121 | case eltwise_logistic_use_dst_for_bwd: |
122 | ds = logistic_bwd_use_dst(dd, s); |
123 | break; |
124 | case eltwise_exp_use_dst_for_bwd: ds = exp_bwd_use_dst(dd, s); break; |
125 | case eltwise_clip_v2_use_dst_for_bwd: |
126 | ds = clip_v2_bwd_use_dst(dd, s, alpha, beta); |
127 | break; |
128 | |
129 | default: assert(!"unknown eltwise alg_kind" ); |
130 | } |
131 | return ds; |
132 | } |
133 | |
134 | ref_binary_scalar_t::ref_binary_scalar_t(alg_kind_t alg) : alg_(alg) { |
135 | assert(utils::one_of(alg_, alg_kind::binary_add, alg_kind::binary_max, |
136 | alg_kind::binary_min, alg_kind::binary_mul, alg_kind::binary_div, |
137 | alg_kind::binary_sub, alg_kind::binary_ge, alg_kind::binary_gt, |
138 | alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq, |
139 | alg_kind::binary_ne)); |
140 | } |
141 | |
142 | ref_binary_scalar_t::ref_binary_scalar_t( |
143 | const post_ops_t::entry_t::binary_t &binary) |
144 | : ref_binary_scalar_t(binary.alg) {} |
145 | |
146 | float ref_binary_scalar_t::compute_scalar(float src0, float src1) const { |
147 | return compute_binary_scalar(alg_, src0, src1); |
148 | } |
149 | |
150 | ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( |
151 | alg_kind_t alg, float alpha, float beta, float scale) |
152 | : alg_(alg), alpha_(alpha), beta_(beta), scale_(scale) { |
153 | assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, |
154 | eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, |
155 | eltwise_soft_relu, eltwise_mish, eltwise_logistic, eltwise_exp, |
156 | eltwise_gelu_tanh, eltwise_swish, eltwise_log, eltwise_clip, |
157 | eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round, |
158 | eltwise_hardsigmoid, eltwise_hardswish, |
159 | eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, |
160 | eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, |
161 | eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, |
162 | eltwise_clip_v2_use_dst_for_bwd)); |
163 | } |
164 | |
165 | ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( |
166 | const post_ops_t::entry_t::eltwise_t &eltwise) |
167 | : ref_eltwise_scalar_fwd_t( |
168 | eltwise.alg, eltwise.alpha, eltwise.beta, eltwise.scale) {} |
169 | |
170 | float ref_eltwise_scalar_fwd_t::compute_scalar(float s) const { |
171 | return compute_eltwise_scalar_fwd(alg_, s, alpha_, beta_) * scale_; |
172 | } |
173 | |
174 | ref_post_ops_t::ref_post_ops_t(const post_ops_t &po, bool skip_sum) |
175 | : po_(po), skip_sum_(skip_sum) { |
176 | for (auto idx = 0; idx < po_.len(); ++idx) { |
177 | const auto &e = po_.entry_[idx]; |
178 | if (po_.contain(primitive_kind::eltwise, idx)) { |
179 | eltwise_po_.emplace_back(e.eltwise); |
180 | } else if (po_.contain(primitive_kind::binary, idx)) { |
181 | binary_po_.emplace_back(e.binary); |
182 | } |
183 | } |
184 | } |
185 | |
186 | namespace { |
187 | |
188 | format_tag_t get_prelu_weights_format(const dim_t n_dims) { |
189 | switch (n_dims) { |
190 | case 1: return format_tag::a; |
191 | case 2: return format_tag::ab; |
192 | case 3: return format_tag::acb; |
193 | case 4: return format_tag::acdb; |
194 | case 5: return format_tag::acdeb; |
195 | } |
196 | |
197 | return format_tag::undef; |
198 | } |
199 | |
200 | memory_desc_t get_prelu_memory_desc( |
201 | const dims_t &dst_dims, const int dst_ndims, int weights_mask) { |
202 | |
203 | memory_desc_t weights_md; |
204 | weights_md.data_type = data_type::f32; |
205 | weights_md.ndims = dst_ndims; |
206 | utils::copy_dims_with_mask( |
207 | weights_md.dims, dst_dims, dst_ndims, weights_mask); |
208 | memory_desc_init_by_tag(weights_md, get_prelu_weights_format(dst_ndims)); |
209 | |
210 | return weights_md; |
211 | } |
212 | |
213 | void get_l_dims_po(dims_t &l_dims_po, const dim_t l_offset, |
214 | const dims_t &dst_dims, const int dst_ndims, int mask) { |
215 | utils::l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims); |
216 | utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask); |
217 | } |
218 | |
219 | dim_t get_po_tensor_off(const memory_desc_t &tensor_md, const dim_t l_offset, |
220 | const dims_t &dst_dims, const int dst_ndims, int mask) { |
221 | |
222 | dims_t l_dims_po {}; |
223 | get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask); |
224 | |
225 | return memory_desc_wrapper(tensor_md).off_v(l_dims_po); |
226 | } |
227 | |
228 | dim_t get_prelu_weights_off(const dim_t l_offset, const dims_t &dst_dims, |
229 | const int dst_ndims, int weights_mask) { |
230 | |
231 | const memory_desc_t &weights_md |
232 | = get_prelu_memory_desc(dst_dims, dst_ndims, weights_mask); |
233 | |
234 | return get_po_tensor_off( |
235 | weights_md, l_offset, dst_dims, dst_ndims, weights_mask); |
236 | } |
237 | |
238 | dim_t get_binary_src1_off(const memory_desc_t &src1_md, const dim_t l_offset, |
239 | const dims_t &dst_dims, const int dst_ndims) { |
240 | |
241 | const int mask_binary_po |
242 | = utils::get_dims_mask(dst_dims, src1_md.dims, dst_ndims); |
243 | |
244 | return get_po_tensor_off( |
245 | src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po); |
246 | } |
247 | |
248 | } // namespace |
249 | |
250 | status_t ref_post_ops_t::execute(float &res, const args_t &args) const { |
251 | if (po_.len() == 0) return status::success; |
252 | |
253 | auto it_eltwise_po = eltwise_po_.begin(); |
254 | auto it_binary_po = binary_po_.begin(); |
255 | for (auto idx = 0; idx < po_.len(); ++idx) { |
256 | const auto &e = po_.entry_[idx]; |
257 | switch (e.kind) { |
258 | case primitive_kind::sum: |
259 | if (!skip_sum_) { |
260 | res += e.sum.scale * (args.dst_val - e.sum.zero_point); |
261 | } |
262 | break; |
263 | case primitive_kind::eltwise: |
264 | res = it_eltwise_po->compute_scalar(res); |
265 | it_eltwise_po++; |
266 | break; |
267 | case primitive_kind::binary: { |
268 | assert(args.ctx); |
269 | assert(args.l_offset >= 0); |
270 | assert(args.dst_md); |
271 | |
272 | const exec_ctx_t &ctx = *args.ctx; |
273 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, args.dst_md); |
274 | const auto &src1_desc = e.binary.src1_desc; |
275 | |
276 | const auto off = get_binary_src1_off( |
277 | src1_desc, args.l_offset, dst_d.dims(), dst_d.ndims()); |
278 | const auto src1_binary_po = CTX_IN_MEM(const void *, |
279 | (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); |
280 | const float val_po = io::load_float_value( |
281 | src1_desc.data_type, src1_binary_po, off); |
282 | res = it_binary_po->compute_scalar(res, val_po); |
283 | ++it_binary_po; |
284 | } break; |
285 | case primitive_kind::prelu: { |
286 | if (res >= 0) break; |
287 | |
288 | assert(args.ctx); |
289 | assert(args.l_offset >= 0); |
290 | assert(args.dst_md); |
291 | |
292 | const exec_ctx_t &ctx = *args.ctx; |
293 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, args.dst_md); |
294 | const auto prelu_weights = CTX_IN_MEM(const float *, |
295 | (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) |
296 | | DNNL_ARG_WEIGHTS)); |
297 | const auto off = get_prelu_weights_off(args.l_offset, |
298 | dst_d.dims(), dst_d.ndims(), e.prelu.mask); |
299 | const auto &weights_value = prelu_weights[off]; |
300 | res = weights_value * res; |
301 | } break; |
302 | default: assert(!"unsupported post op primitive kind!" ); |
303 | } |
304 | } |
305 | return status::success; |
306 | } |
307 | |
308 | } // namespace cpu |
309 | } // namespace impl |
310 | } // namespace dnnl |
311 | |