1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cmath>
18
19#include "cpu/primitive_attr_postops.hpp"
20#include "cpu/ref_io_helper.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25
26using namespace alg_kind;
27using namespace math;
28
29float compute_binary_scalar(alg_kind_t alg, float x, float y) {
30 switch (alg) {
31 case binary_add: return x + y;
32 case binary_div: return x / y;
33 case binary_max: return nstl::max(x, y);
34 case binary_min: return nstl::min(x, y);
35 case binary_mul: return x * y;
36 case binary_sub: return x - y;
37 case binary_ge: return x >= y;
38 case binary_gt: return x > y;
39 case binary_le: return x <= y;
40 case binary_lt: return x < y;
41 case binary_eq: return x == y;
42 case binary_ne: return x != y;
43 default: assert(!"not supported operation!"); return NAN;
44 }
45}
46
47float compute_eltwise_scalar_fwd(
48 const alg_kind_t alg, float s, float alpha, float beta) {
49 float d = 0.f;
50 switch (alg) {
51 case eltwise_relu: d = relu_fwd(s, alpha); break;
52 case eltwise_tanh: d = tanh_fwd(s); break;
53 case eltwise_elu: d = elu_fwd(s, alpha); break;
54 case eltwise_square: d = square_fwd(s); break;
55 case eltwise_abs: d = abs_fwd(s); break;
56 case eltwise_sqrt: d = sqrt_fwd(s); break;
57 case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
58 case eltwise_soft_relu: d = soft_relu_fwd(s, alpha); break;
59 case eltwise_logistic: d = logistic_fwd(s); break;
60 case eltwise_exp: d = exp_fwd(s); break;
61 case eltwise_gelu_tanh: d = gelu_tanh_fwd(s); break;
62 case eltwise_swish: d = swish_fwd(s, alpha); break;
63 case eltwise_log: d = log_fwd(s); break;
64 case eltwise_clip: d = clip_fwd(s, alpha, beta); break;
65 case eltwise_clip_v2: d = clip_v2_fwd(s, alpha, beta); break;
66 case eltwise_pow: d = pow_fwd(s, alpha, beta); break;
67 case eltwise_gelu_erf: d = gelu_erf_fwd(s); break;
68 case eltwise_round: d = round_fwd(s); break;
69 case eltwise_mish: d = mish_fwd(s); break;
70 case eltwise_hardsigmoid: d = hardsigmoid_fwd(s, alpha, beta); break;
71 case eltwise_hardswish: d = hardswish_fwd(s, alpha, beta); break;
72 case eltwise_relu_use_dst_for_bwd: d = relu_fwd(s, alpha); break;
73 case eltwise_tanh_use_dst_for_bwd: d = tanh_fwd(s); break;
74 case eltwise_elu_use_dst_for_bwd: d = elu_fwd(s, alpha); break;
75 case eltwise_sqrt_use_dst_for_bwd: d = sqrt_fwd(s); break;
76 case eltwise_logistic_use_dst_for_bwd: d = logistic_fwd(s); break;
77 case eltwise_exp_use_dst_for_bwd: d = exp_fwd(s); break;
78 case eltwise_clip_v2_use_dst_for_bwd:
79 d = clip_v2_fwd(s, alpha, beta);
80 break;
81
82 default: assert(!"unknown eltwise alg_kind");
83 }
84 return d;
85}
86
87float compute_eltwise_scalar_bwd(
88 const alg_kind_t alg, float dd, float s, float alpha, float beta) {
89 float ds = 0.f;
90 switch (alg) {
91 case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
92 case eltwise_tanh: ds = tanh_bwd(dd, s); break;
93 case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
94 case eltwise_square: ds = square_bwd(dd, s); break;
95 case eltwise_abs: ds = abs_bwd(dd, s); break;
96 case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
97 case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
98 case eltwise_soft_relu: ds = soft_relu_bwd(dd, s, alpha); break;
99 case eltwise_logistic: ds = logistic_bwd(dd, s); break;
100 case eltwise_exp: ds = exp_bwd(dd, s); break;
101 case eltwise_gelu_tanh: ds = gelu_tanh_bwd(dd, s); break;
102 case eltwise_swish: ds = swish_bwd(dd, s, alpha); break;
103 case eltwise_log: ds = log_bwd(dd, s); break;
104 case eltwise_clip: ds = clip_bwd(dd, s, alpha, beta); break;
105 case eltwise_clip_v2: ds = clip_v2_bwd(dd, s, alpha, beta); break;
106 case eltwise_pow: ds = pow_bwd(dd, s, alpha, beta); break;
107 case eltwise_gelu_erf: ds = gelu_erf_bwd(dd, s); break;
108 case eltwise_mish: ds = mish_bwd(dd, s); break;
109 case eltwise_hardsigmoid:
110 ds = hardsigmoid_bwd(dd, s, alpha, beta);
111 break;
112 case eltwise_hardswish: ds = hardswish_bwd(dd, s, alpha, beta); break;
113 case eltwise_relu_use_dst_for_bwd:
114 ds = relu_bwd_use_dst(dd, s, alpha);
115 break;
116 case eltwise_tanh_use_dst_for_bwd: ds = tanh_bwd_use_dst(dd, s); break;
117 case eltwise_elu_use_dst_for_bwd:
118 ds = elu_bwd_use_dst(dd, s, alpha);
119 break;
120 case eltwise_sqrt_use_dst_for_bwd: ds = sqrt_bwd_use_dst(dd, s); break;
121 case eltwise_logistic_use_dst_for_bwd:
122 ds = logistic_bwd_use_dst(dd, s);
123 break;
124 case eltwise_exp_use_dst_for_bwd: ds = exp_bwd_use_dst(dd, s); break;
125 case eltwise_clip_v2_use_dst_for_bwd:
126 ds = clip_v2_bwd_use_dst(dd, s, alpha, beta);
127 break;
128
129 default: assert(!"unknown eltwise alg_kind");
130 }
131 return ds;
132}
133
134ref_binary_scalar_t::ref_binary_scalar_t(alg_kind_t alg) : alg_(alg) {
135 assert(utils::one_of(alg_, alg_kind::binary_add, alg_kind::binary_max,
136 alg_kind::binary_min, alg_kind::binary_mul, alg_kind::binary_div,
137 alg_kind::binary_sub, alg_kind::binary_ge, alg_kind::binary_gt,
138 alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq,
139 alg_kind::binary_ne));
140}
141
142ref_binary_scalar_t::ref_binary_scalar_t(
143 const post_ops_t::entry_t::binary_t &binary)
144 : ref_binary_scalar_t(binary.alg) {}
145
146float ref_binary_scalar_t::compute_scalar(float src0, float src1) const {
147 return compute_binary_scalar(alg_, src0, src1);
148}
149
150ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
151 alg_kind_t alg, float alpha, float beta, float scale)
152 : alg_(alg), alpha_(alpha), beta_(beta), scale_(scale) {
153 assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
154 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
155 eltwise_soft_relu, eltwise_mish, eltwise_logistic, eltwise_exp,
156 eltwise_gelu_tanh, eltwise_swish, eltwise_log, eltwise_clip,
157 eltwise_clip_v2, eltwise_pow, eltwise_gelu_erf, eltwise_round,
158 eltwise_hardsigmoid, eltwise_hardswish,
159 eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
160 eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
161 eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
162 eltwise_clip_v2_use_dst_for_bwd));
163}
164
165ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
166 const post_ops_t::entry_t::eltwise_t &eltwise)
167 : ref_eltwise_scalar_fwd_t(
168 eltwise.alg, eltwise.alpha, eltwise.beta, eltwise.scale) {}
169
170float ref_eltwise_scalar_fwd_t::compute_scalar(float s) const {
171 return compute_eltwise_scalar_fwd(alg_, s, alpha_, beta_) * scale_;
172}
173
174ref_post_ops_t::ref_post_ops_t(const post_ops_t &po, bool skip_sum)
175 : po_(po), skip_sum_(skip_sum) {
176 for (auto idx = 0; idx < po_.len(); ++idx) {
177 const auto &e = po_.entry_[idx];
178 if (po_.contain(primitive_kind::eltwise, idx)) {
179 eltwise_po_.emplace_back(e.eltwise);
180 } else if (po_.contain(primitive_kind::binary, idx)) {
181 binary_po_.emplace_back(e.binary);
182 }
183 }
184}
185
186namespace {
187
188format_tag_t get_prelu_weights_format(const dim_t n_dims) {
189 switch (n_dims) {
190 case 1: return format_tag::a;
191 case 2: return format_tag::ab;
192 case 3: return format_tag::acb;
193 case 4: return format_tag::acdb;
194 case 5: return format_tag::acdeb;
195 }
196
197 return format_tag::undef;
198}
199
200memory_desc_t get_prelu_memory_desc(
201 const dims_t &dst_dims, const int dst_ndims, int weights_mask) {
202
203 memory_desc_t weights_md;
204 weights_md.data_type = data_type::f32;
205 weights_md.ndims = dst_ndims;
206 utils::copy_dims_with_mask(
207 weights_md.dims, dst_dims, dst_ndims, weights_mask);
208 memory_desc_init_by_tag(weights_md, get_prelu_weights_format(dst_ndims));
209
210 return weights_md;
211}
212
213void get_l_dims_po(dims_t &l_dims_po, const dim_t l_offset,
214 const dims_t &dst_dims, const int dst_ndims, int mask) {
215 utils::l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
216 utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
217}
218
219dim_t get_po_tensor_off(const memory_desc_t &tensor_md, const dim_t l_offset,
220 const dims_t &dst_dims, const int dst_ndims, int mask) {
221
222 dims_t l_dims_po {};
223 get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);
224
225 return memory_desc_wrapper(tensor_md).off_v(l_dims_po);
226}
227
228dim_t get_prelu_weights_off(const dim_t l_offset, const dims_t &dst_dims,
229 const int dst_ndims, int weights_mask) {
230
231 const memory_desc_t &weights_md
232 = get_prelu_memory_desc(dst_dims, dst_ndims, weights_mask);
233
234 return get_po_tensor_off(
235 weights_md, l_offset, dst_dims, dst_ndims, weights_mask);
236}
237
238dim_t get_binary_src1_off(const memory_desc_t &src1_md, const dim_t l_offset,
239 const dims_t &dst_dims, const int dst_ndims) {
240
241 const int mask_binary_po
242 = utils::get_dims_mask(dst_dims, src1_md.dims, dst_ndims);
243
244 return get_po_tensor_off(
245 src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
246}
247
248} // namespace
249
250status_t ref_post_ops_t::execute(float &res, const args_t &args) const {
251 if (po_.len() == 0) return status::success;
252
253 auto it_eltwise_po = eltwise_po_.begin();
254 auto it_binary_po = binary_po_.begin();
255 for (auto idx = 0; idx < po_.len(); ++idx) {
256 const auto &e = po_.entry_[idx];
257 switch (e.kind) {
258 case primitive_kind::sum:
259 if (!skip_sum_) {
260 res += e.sum.scale * (args.dst_val - e.sum.zero_point);
261 }
262 break;
263 case primitive_kind::eltwise:
264 res = it_eltwise_po->compute_scalar(res);
265 it_eltwise_po++;
266 break;
267 case primitive_kind::binary: {
268 assert(args.ctx);
269 assert(args.l_offset >= 0);
270 assert(args.dst_md);
271
272 const exec_ctx_t &ctx = *args.ctx;
273 const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, args.dst_md);
274 const auto &src1_desc = e.binary.src1_desc;
275
276 const auto off = get_binary_src1_off(
277 src1_desc, args.l_offset, dst_d.dims(), dst_d.ndims());
278 const auto src1_binary_po = CTX_IN_MEM(const void *,
279 (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
280 const float val_po = io::load_float_value(
281 src1_desc.data_type, src1_binary_po, off);
282 res = it_binary_po->compute_scalar(res, val_po);
283 ++it_binary_po;
284 } break;
285 case primitive_kind::prelu: {
286 if (res >= 0) break;
287
288 assert(args.ctx);
289 assert(args.l_offset >= 0);
290 assert(args.dst_md);
291
292 const exec_ctx_t &ctx = *args.ctx;
293 const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, args.dst_md);
294 const auto prelu_weights = CTX_IN_MEM(const float *,
295 (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
296 | DNNL_ARG_WEIGHTS));
297 const auto off = get_prelu_weights_off(args.l_offset,
298 dst_d.dims(), dst_d.ndims(), e.prelu.mask);
299 const auto &weights_value = prelu_weights[off];
300 res = weights_value * res;
301 } break;
302 default: assert(!"unsupported post op primitive kind!");
303 }
304 }
305 return status::success;
306}
307
308} // namespace cpu
309} // namespace impl
310} // namespace dnnl
311