1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_MATH_UTILS_HPP
18#define COMMON_MATH_UTILS_HPP
19
20#include <math.h>
21#include <stdint.h>
22
23#include "dnnl_traits.hpp"
24#include "nstl.hpp"
25#include "utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace math {
30
31inline int gcd(int a, int b) {
32 a = impl::nstl::abs(a);
33 b = impl::nstl::abs(b);
34 if (a < b) {
35 int x = a;
36 a = b;
37 b = x;
38 }
39
40 if (b == 0) return a;
41
42 int r;
43 while ((r = a % b) != 0) {
44 a = b;
45 b = r;
46 }
47
48 return b;
49}
50
51inline int lcm(int a, int b) {
52 a = impl::nstl::abs(a);
53 b = impl::nstl::abs(b);
54 assert(a > 0 && b > 0);
55
56 return a * b / gcd(a, b);
57}
58
/** Returns true iff @p v is positive and a power of two (1, 2, 4, ...). */
template <typename T>
inline bool is_pow2(const T &v) {
    if (!(v > 0)) return false;
    // A power of two has exactly one bit set; clearing the lowest set bit
    // (v & (v - 1)) must therefore leave nothing behind.
    return (v & (v - 1)) == 0;
}
63
/**
 * Returns floor(log2(v)), i.e. the index of the most significant set bit of
 * @p v, or -1 when v == 0. Uses a binary search over shift widths
 * 32, 16, 8, 4, 2, 1.
 */
inline int ilog2q(size_t v) {
    if (v == 0) return -1;

    int msb = 0;
    for (int width = 32; width > 0; width /= 2) {
        // Compare against a 64-bit constant so the 32-bit step is well
        // defined even when size_t is only 32 bits wide (the branch is then
        // simply never taken).
        if (v >= (1ull << width)) {
            v >>= width;
            msb += width;
        }
    }
    return msb;
}
85
86template <typename T, typename U = typename utils::remove_reference<T>::type>
87inline U one_m_square(T x) {
88 return (U)(1 - x) * (1 + x);
89}
90
91template <typename T, typename U = typename utils::remove_reference<T>::type>
92inline U x_m_square(T x) {
93 return (U)(1 - x) * x;
94}
95
96/* activation */
97
98/** rounds @p f to an integer according to the mxcsr register */
99inline float mxcsr_round(float f) ATTR_NO_MSAN {
100 return nearbyintf(f);
101}
102
103/** converts @p f to an integer according to the mxcsr register */
104inline int mxcsr_cvt(float f) ATTR_NO_MSAN {
105 return (int)mxcsr_round(f);
106}
107
108inline float round_fwd(float s) {
109 return mxcsr_round(s);
110}
111
112template <typename T, typename A,
113 typename U = typename utils::remove_reference<T>::type>
114inline typename utils::enable_if<nstl::is_integral<U>::value, U>::type relu_fwd(
115 T s, A alpha) {
116 return s > 0 ? s : (U)mxcsr_cvt(static_cast<float>(s * alpha));
117}
118
119template <typename T, typename A,
120 typename U = typename utils::remove_reference<T>::type>
121inline typename utils::enable_if<!nstl::is_integral<U>::value, U>::type
122relu_fwd(T s, A alpha) ATTR_NO_MSAN {
123 return s > 0 ? s : (U)(s * alpha);
124}
125
126template <typename T, typename A,
127 typename U = typename utils::remove_reference<T>::type>
128inline U relu_bwd(T dd, T s, A alpha) {
129 return s > 0 ? dd : (U)(dd * alpha);
130}
131template <typename T, typename A,
132 typename U = typename utils::remove_reference<T>::type>
133inline U relu_bwd(T s, A alpha) {
134 return s > 0 ? (U)1 : (U)alpha;
135}
136template <typename T, typename A,
137 typename U = typename utils::remove_reference<T>::type>
138inline U relu_bwd_use_dst(T dd, T d, A alpha) {
139 return d > 0 ? dd : (U)(dd * alpha);
140}
141
142template <typename T, typename U = typename utils::remove_reference<T>::type>
143inline U tanh_fwd(T s) {
144 const float e = tanhf((float)s);
145 return (U)e;
146}
147template <typename T, typename U = typename utils::remove_reference<T>::type>
148inline U tanh_bwd(T dd, T s) {
149 const float e = tanh_fwd<float>((float)s);
150 return (U)(dd * (1 - e) * (1 + e));
151}
152template <typename T, typename U = typename utils::remove_reference<T>::type>
153inline U tanh_bwd_use_dst(T dd, T d) {
154 return (U)(dd * (1 - d) * (1 + d));
155}
156
157template <typename T, typename A,
158 typename U = typename utils::remove_reference<T>::type>
159inline U elu_fwd(T s, A alpha) {
160 return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
161}
162template <typename T, typename A,
163 typename U = typename utils::remove_reference<T>::type>
164inline U elu_bwd(T dd, T s, A alpha) {
165 return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
166}
167template <typename T, typename A,
168 typename U = typename utils::remove_reference<T>::type>
169inline U elu_bwd_use_dst(T dd, T d, A alpha) {
170 return (U)(dd * (d > 0 ? 1 : d + alpha));
171}
172
173template <typename T, typename U = typename utils::remove_reference<T>::type>
174inline U square_fwd(T s) {
175 return s * s;
176}
177template <typename T, typename U = typename utils::remove_reference<T>::type>
178inline U square_bwd(T dd, T s) {
179 return dd * 2 * s;
180}
181
182template <typename T, typename U = typename utils::remove_reference<T>::type>
183inline U abs_fwd(T s) {
184 return s > 0 ? s : (U)-s;
185}
186template <typename T, typename U = typename utils::remove_reference<T>::type>
187inline U abs_bwd(T dd, T s) {
188 return s > 0 ? dd : s < 0 ? (U)-dd : (U)0;
189}
190
191template <typename T, typename U = typename utils::remove_reference<T>::type>
192inline U sqrt_fwd(T s) {
193 return (U)(::sqrtf((float)(s)));
194}
195template <typename T, typename U = typename utils::remove_reference<T>::type>
196inline U sqrt_bwd(T dd, T s) {
197 return (U)(dd / (2 * ::sqrtf((float)(s))));
198}
199template <typename T, typename U = typename utils::remove_reference<T>::type>
200inline U sqrt_bwd_use_dst(T dd, T d) {
201 return (U)(dd / (2 * d));
202}
203
204template <typename T, typename A,
205 typename U = typename utils::remove_reference<T>::type>
206inline U linear_fwd(T s, A alpha, A beta) {
207 return (U)(alpha * s + beta);
208}
209template <typename T, typename A,
210 typename U = typename utils::remove_reference<T>::type>
211inline U linear_bwd(T dd, T s, A alpha, A beta) {
212 (void)s;
213 (void)beta;
214 return (U)(dd * alpha);
215}
216
217template <typename T, typename U = typename utils::remove_reference<T>::type>
218inline U logistic_fwd(T s) {
219 // Here we avoid division/inverse by infinity as some architectures have
220 // non-standard behavior
221 float exp_overflow_bound = 88.72283172607421875;
222 float in = (float)-s;
223 return in < exp_overflow_bound ? (U)(1.f / (1.f + ::expf(in))) : 0.f;
224}
225template <typename T, typename U = typename utils::remove_reference<T>::type>
226inline U logistic_bwd(T dd, T s) {
227 float v = logistic_fwd<float>(s);
228 return (U)(dd * v * (1 - v));
229}
230template <typename T, typename U = typename utils::remove_reference<T>::type>
231inline U logistic_bwd_use_dst(T dd, T d) {
232 return (U)(dd * d * (1 - d));
233}
234
235template <typename T, typename A,
236 typename U = typename utils::remove_reference<T>::type>
237inline U soft_relu_fwd(T s, A alpha) {
238 float exp_overflow_bound = 88.72283172607421875;
239 float in = (float)s * (float)alpha;
240 float v = (in < exp_overflow_bound ? (U)(::log1pf(::expf(in))) : (U)in);
241 return (U)(v / alpha);
242}
243template <typename T, typename A,
244 typename U = typename utils::remove_reference<T>::type>
245inline U soft_relu_bwd(T dd, T s, A alpha) {
246 float in = (float)s * (float)alpha;
247 return (U)(dd * logistic_fwd<float>(in));
248}
249
250template <typename T, typename U = typename utils::remove_reference<T>::type>
251inline U mish_fwd(T s) {
252 return s * tanh_fwd(soft_relu_fwd(s, 1.f));
253}
254template <typename T, typename U = typename utils::remove_reference<T>::type>
255inline U mish_bwd(T dd, T s) {
256 const float tanh = tanh_fwd(soft_relu_fwd(s, 1.f));
257 const float srelu_bwd = soft_relu_bwd(1.f, s, 1.f);
258 const float derivative = tanh + s * srelu_bwd * (1 - ::powf(tanh, 2.0f));
259 return dd * derivative;
260}
261
262template <typename T, typename A,
263 typename U = typename utils::remove_reference<T>::type>
264inline U swish_fwd(T s, A alpha) {
265 return (U)(s * logistic_fwd<float>(alpha * s));
266}
267template <typename T, typename A,
268 typename U = typename utils::remove_reference<T>::type>
269inline U swish_bwd(T dd, T s, A alpha) {
270 float v = logistic_fwd<float>(alpha * s);
271 return dd * (v + s * alpha * v * (1 - v));
272}
273
274template <typename T, typename U = typename utils::remove_reference<T>::type>
275inline U exp_fwd(T s) {
276 return (U)(::expf((float)s));
277}
278template <typename T, typename U = typename utils::remove_reference<T>::type>
279inline U exp_bwd(T dd, T s) {
280 return (U)(dd * (::expf((float)s)));
281}
282template <typename T, typename U = typename utils::remove_reference<T>::type>
283inline U exp_bwd_use_dst(T dd, T d) {
284 return (U)(dd * d);
285}
286
287template <typename T, typename U = typename utils::remove_reference<T>::type>
288inline U gelu_tanh_fwd(T s) {
289 const float sqrt_2_over_pi = 0.79788458347320556640625f;
290 const float fitting_const = 0.044715f;
291 float v = tanh_fwd(sqrt_2_over_pi * s * (1 + fitting_const * s * s));
292 return (U)(0.5 * s * (1. + v));
293}
294template <typename T, typename U = typename utils::remove_reference<T>::type>
295inline U gelu_tanh_bwd(T dd, T s) {
296 const float sqrt_2_over_pi = 0.79788458347320556640625f;
297 const float fitting_const = 0.044715f;
298 float g = s * sqrt_2_over_pi * (1 + fitting_const * s * s);
299 float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s);
300 float v = tanh_fwd(g);
301 return (U)(dd * 0.5 * (1. + v) * (1. + s * (1 - v) * dg));
302}
303
304template <typename T, typename U = typename utils::remove_reference<T>::type>
305inline U log_fwd(T s) {
306 return (U)(::logf((float)s));
307}
308template <typename T, typename U = typename utils::remove_reference<T>::type>
309inline U log_bwd(T dd, T s) {
310 return (U)(dd * (1.f / (float)s));
311}
312
313template <typename T, typename A,
314 typename U = typename utils::remove_reference<T>::type>
315inline U clip_fwd(T s, A alpha, A beta) {
316 s = s > alpha ? s : (U)alpha;
317 return s > beta ? (U)beta : s;
318}
319template <typename T, typename A,
320 typename U = typename utils::remove_reference<T>::type>
321inline U clip_bwd(T dd, T s, A alpha, A beta) {
322 return dd * (alpha < s && s <= beta ? 1 : 0);
323}
324
325template <typename T, typename A,
326 typename U = typename utils::remove_reference<T>::type>
327inline U clip_v2_fwd(T s, A alpha, A beta) {
328 s = s > alpha ? s : (U)alpha;
329 return s < beta ? s : (U)beta;
330}
331template <typename T, typename A,
332 typename U = typename utils::remove_reference<T>::type>
333inline U clip_v2_bwd(T dd, T s, A alpha, A beta) {
334 return dd * (alpha < s && s < beta ? 1 : 0);
335}
336template <typename T, typename A,
337 typename U = typename utils::remove_reference<T>::type>
338inline U clip_v2_bwd_use_dst(T dd, T d, A alpha, A beta) {
339 return clip_v2_bwd(dd, d, alpha, beta);
340}
341
342template <typename T, typename A,
343 typename U = typename utils::remove_reference<T>::type>
344inline U pow_fwd(T s, A alpha, A beta) {
345 return (U)(alpha * ::powf((float)s, beta));
346}
347template <typename T, typename A,
348 typename U = typename utils::remove_reference<T>::type>
349inline U pow_bwd(T dd, T s, A alpha, A beta) {
350 if (beta == 0) return 0;
351
352 float v = pow_fwd(s, alpha * beta, beta - 1);
353 return (U)(dd * v);
354}
355
356template <typename T, typename U = typename utils::remove_reference<T>::type>
357inline U gelu_erf_fwd(T s) {
358 const float sqrt_2_over_2 = 0.707106769084930419921875f;
359 float v = s * sqrt_2_over_2;
360 return (U)(0.5f * s * (1.f + ::erff(v)));
361}
362template <typename T, typename U = typename utils::remove_reference<T>::type>
363inline U gelu_erf_bwd(T dd, T s) {
364 const float two_over_sqrt_pi = 1.12837922573089599609375f;
365 const float sqrt_2_over_2 = 0.707106769084930419921875f;
366 float v = s * sqrt_2_over_2;
367 return (U)(dd * 0.5f
368 * (1.f + ::erff(v) + v * two_over_sqrt_pi * ::expf(-v * v)));
369}
370
371template <typename T, typename A,
372 typename U = typename utils::remove_reference<T>::type>
373inline U hardsigmoid_fwd(T s, A alpha, A beta) {
374 float v = alpha * s + beta;
375 return v <= 0.f ? 0.f : v >= 1.f ? 1.f : v;
376}
377template <typename T, typename A,
378 typename U = typename utils::remove_reference<T>::type>
379inline U hardsigmoid_bwd(T dd, T s, A alpha, A beta) {
380 float v = alpha * s + beta;
381 return v <= 0.f ? 0.f : v >= 1.f ? 0.f : dd * alpha;
382}
383
384template <typename T, typename A,
385 typename U = typename utils::remove_reference<T>::type>
386inline U hardswish_fwd(T s, A alpha, A beta) {
387 return (U)(s * hardsigmoid_fwd(s, alpha, beta));
388}
389template <typename T, typename A,
390 typename U = typename utils::remove_reference<T>::type>
391inline U hardswish_bwd(T dd, T s, A alpha, A beta) {
392 float v = alpha * s + beta;
393 float w = 2.f * alpha * s + beta;
394 return v <= 0.f ? 0.f : v >= 1.f ? dd : dd * w;
395}
396
397inline bool is_eltwise_ok(
398 data_type_t src_dt, alg_kind_t alg, float alpha, float beta) {
399 using namespace alg_kind;
400 using namespace utils;
401
402 const bool eltwise_use_src
403 = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
404 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
405 eltwise_soft_relu, eltwise_mish, eltwise_logistic,
406 eltwise_exp, eltwise_gelu_tanh, eltwise_hardsigmoid,
407 eltwise_hardswish, eltwise_swish, eltwise_log,
408 eltwise_clip, eltwise_clip_v2, eltwise_pow,
409 eltwise_gelu_erf, eltwise_round)
410 && IMPLICATION(
411 one_of(alg, eltwise_clip, eltwise_clip_v2), beta >= alpha)
412 && IMPLICATION(alg == eltwise_round, src_dt == dnnl_f32)
413 && IMPLICATION(one_of(src_dt, dnnl_s32, dnnl_s8, dnnl_u8),
414 one_of(alg, eltwise_relu, eltwise_linear));
415
416 const bool eltwise_use_dst
417 = one_of(alg, eltwise_relu_use_dst_for_bwd,
418 eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd,
419 eltwise_sqrt_use_dst_for_bwd,
420 eltwise_logistic_use_dst_for_bwd,
421 eltwise_exp_use_dst_for_bwd,
422 eltwise_clip_v2_use_dst_for_bwd)
423 && IMPLICATION(one_of(alg, eltwise_relu_use_dst_for_bwd,
424 eltwise_elu_use_dst_for_bwd),
425 alpha >= 0)
426 && IMPLICATION(
427 alg == eltwise_clip_v2_use_dst_for_bwd, beta >= alpha);
428
429 return eltwise_use_src || eltwise_use_dst;
430}
431
432} // namespace math
433} // namespace impl
434} // namespace dnnl
435
436#endif
437