1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_MATH_UTILS_HPP |
18 | #define COMMON_MATH_UTILS_HPP |
19 | |
20 | #include <math.h> |
21 | #include <stdint.h> |
22 | |
23 | #include "dnnl_traits.hpp" |
24 | #include "nstl.hpp" |
25 | #include "utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace math { |
30 | |
31 | inline int gcd(int a, int b) { |
32 | a = impl::nstl::abs(a); |
33 | b = impl::nstl::abs(b); |
34 | if (a < b) { |
35 | int x = a; |
36 | a = b; |
37 | b = x; |
38 | } |
39 | |
40 | if (b == 0) return a; |
41 | |
42 | int r; |
43 | while ((r = a % b) != 0) { |
44 | a = b; |
45 | b = r; |
46 | } |
47 | |
48 | return b; |
49 | } |
50 | |
51 | inline int lcm(int a, int b) { |
52 | a = impl::nstl::abs(a); |
53 | b = impl::nstl::abs(b); |
54 | assert(a > 0 && b > 0); |
55 | |
56 | return a * b / gcd(a, b); |
57 | } |
58 | |
/** Returns true when @p v is strictly positive and has exactly one bit set. */
template <typename T>
inline bool is_pow2(const T &v) {
    // Non-positive values are rejected first, so `v - 1` is only evaluated
    // for v > 0.
    if (!(v > 0)) return false;
    // Clearing the lowest set bit leaves zero only for a single-bit value.
    return (v & (v - 1)) == 0;
}
63 | |
/** returns floor(log2(v)), aka the position of the leftmost non-0 bit;
 *  returns -1 for v == 0 */
inline int ilog2q(size_t v) {
    if (v == 0) return -1;

    // Binary search over the bit position: test shifts 32, 16, 8, 4, 2, 1 in
    // that order, accumulating every shift that still leaves a non-zero value.
    int pos = 0;
    for (int shift = 32; shift >= 1; shift /= 2) {
        if (v >= (1ull << shift)) {
            v >>= shift;
            pos += shift;
        }
    }
    return pos;
}
85 | |
86 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
87 | inline U one_m_square(T x) { |
88 | return (U)(1 - x) * (1 + x); |
89 | } |
90 | |
91 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
92 | inline U x_m_square(T x) { |
93 | return (U)(1 - x) * x; |
94 | } |
95 | |
96 | /* activation */ |
97 | |
98 | /** rounds @p f to an integer according to the mxcsr register */ |
99 | inline float mxcsr_round(float f) ATTR_NO_MSAN { |
100 | return nearbyintf(f); |
101 | } |
102 | |
103 | /** converts @p f to an integer according to the mxcsr register */ |
104 | inline int mxcsr_cvt(float f) ATTR_NO_MSAN { |
105 | return (int)mxcsr_round(f); |
106 | } |
107 | |
108 | inline float round_fwd(float s) { |
109 | return mxcsr_round(s); |
110 | } |
111 | |
112 | template <typename T, typename A, |
113 | typename U = typename utils::remove_reference<T>::type> |
114 | inline typename utils::enable_if<nstl::is_integral<U>::value, U>::type relu_fwd( |
115 | T s, A alpha) { |
116 | return s > 0 ? s : (U)mxcsr_cvt(static_cast<float>(s * alpha)); |
117 | } |
118 | |
119 | template <typename T, typename A, |
120 | typename U = typename utils::remove_reference<T>::type> |
121 | inline typename utils::enable_if<!nstl::is_integral<U>::value, U>::type |
122 | relu_fwd(T s, A alpha) ATTR_NO_MSAN { |
123 | return s > 0 ? s : (U)(s * alpha); |
124 | } |
125 | |
126 | template <typename T, typename A, |
127 | typename U = typename utils::remove_reference<T>::type> |
128 | inline U relu_bwd(T dd, T s, A alpha) { |
129 | return s > 0 ? dd : (U)(dd * alpha); |
130 | } |
131 | template <typename T, typename A, |
132 | typename U = typename utils::remove_reference<T>::type> |
133 | inline U relu_bwd(T s, A alpha) { |
134 | return s > 0 ? (U)1 : (U)alpha; |
135 | } |
136 | template <typename T, typename A, |
137 | typename U = typename utils::remove_reference<T>::type> |
138 | inline U relu_bwd_use_dst(T dd, T d, A alpha) { |
139 | return d > 0 ? dd : (U)(dd * alpha); |
140 | } |
141 | |
142 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
143 | inline U tanh_fwd(T s) { |
144 | const float e = tanhf((float)s); |
145 | return (U)e; |
146 | } |
147 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
148 | inline U tanh_bwd(T dd, T s) { |
149 | const float e = tanh_fwd<float>((float)s); |
150 | return (U)(dd * (1 - e) * (1 + e)); |
151 | } |
152 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
153 | inline U tanh_bwd_use_dst(T dd, T d) { |
154 | return (U)(dd * (1 - d) * (1 + d)); |
155 | } |
156 | |
157 | template <typename T, typename A, |
158 | typename U = typename utils::remove_reference<T>::type> |
159 | inline U elu_fwd(T s, A alpha) { |
160 | return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); |
161 | } |
162 | template <typename T, typename A, |
163 | typename U = typename utils::remove_reference<T>::type> |
164 | inline U elu_bwd(T dd, T s, A alpha) { |
165 | return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); |
166 | } |
167 | template <typename T, typename A, |
168 | typename U = typename utils::remove_reference<T>::type> |
169 | inline U elu_bwd_use_dst(T dd, T d, A alpha) { |
170 | return (U)(dd * (d > 0 ? 1 : d + alpha)); |
171 | } |
172 | |
173 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
174 | inline U square_fwd(T s) { |
175 | return s * s; |
176 | } |
177 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
178 | inline U square_bwd(T dd, T s) { |
179 | return dd * 2 * s; |
180 | } |
181 | |
182 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
183 | inline U abs_fwd(T s) { |
184 | return s > 0 ? s : (U)-s; |
185 | } |
186 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
187 | inline U abs_bwd(T dd, T s) { |
188 | return s > 0 ? dd : s < 0 ? (U)-dd : (U)0; |
189 | } |
190 | |
191 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
192 | inline U sqrt_fwd(T s) { |
193 | return (U)(::sqrtf((float)(s))); |
194 | } |
195 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
196 | inline U sqrt_bwd(T dd, T s) { |
197 | return (U)(dd / (2 * ::sqrtf((float)(s)))); |
198 | } |
199 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
200 | inline U sqrt_bwd_use_dst(T dd, T d) { |
201 | return (U)(dd / (2 * d)); |
202 | } |
203 | |
204 | template <typename T, typename A, |
205 | typename U = typename utils::remove_reference<T>::type> |
206 | inline U linear_fwd(T s, A alpha, A beta) { |
207 | return (U)(alpha * s + beta); |
208 | } |
209 | template <typename T, typename A, |
210 | typename U = typename utils::remove_reference<T>::type> |
211 | inline U linear_bwd(T dd, T s, A alpha, A beta) { |
212 | (void)s; |
213 | (void)beta; |
214 | return (U)(dd * alpha); |
215 | } |
216 | |
217 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
218 | inline U logistic_fwd(T s) { |
219 | // Here we avoid division/inverse by infinity as some architectures have |
220 | // non-standard behavior |
221 | float exp_overflow_bound = 88.72283172607421875; |
222 | float in = (float)-s; |
223 | return in < exp_overflow_bound ? (U)(1.f / (1.f + ::expf(in))) : 0.f; |
224 | } |
225 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
226 | inline U logistic_bwd(T dd, T s) { |
227 | float v = logistic_fwd<float>(s); |
228 | return (U)(dd * v * (1 - v)); |
229 | } |
230 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
231 | inline U logistic_bwd_use_dst(T dd, T d) { |
232 | return (U)(dd * d * (1 - d)); |
233 | } |
234 | |
235 | template <typename T, typename A, |
236 | typename U = typename utils::remove_reference<T>::type> |
237 | inline U soft_relu_fwd(T s, A alpha) { |
238 | float exp_overflow_bound = 88.72283172607421875; |
239 | float in = (float)s * (float)alpha; |
240 | float v = (in < exp_overflow_bound ? (U)(::log1pf(::expf(in))) : (U)in); |
241 | return (U)(v / alpha); |
242 | } |
243 | template <typename T, typename A, |
244 | typename U = typename utils::remove_reference<T>::type> |
245 | inline U soft_relu_bwd(T dd, T s, A alpha) { |
246 | float in = (float)s * (float)alpha; |
247 | return (U)(dd * logistic_fwd<float>(in)); |
248 | } |
249 | |
250 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
251 | inline U mish_fwd(T s) { |
252 | return s * tanh_fwd(soft_relu_fwd(s, 1.f)); |
253 | } |
254 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
255 | inline U mish_bwd(T dd, T s) { |
256 | const float tanh = tanh_fwd(soft_relu_fwd(s, 1.f)); |
257 | const float srelu_bwd = soft_relu_bwd(1.f, s, 1.f); |
258 | const float derivative = tanh + s * srelu_bwd * (1 - ::powf(tanh, 2.0f)); |
259 | return dd * derivative; |
260 | } |
261 | |
262 | template <typename T, typename A, |
263 | typename U = typename utils::remove_reference<T>::type> |
264 | inline U swish_fwd(T s, A alpha) { |
265 | return (U)(s * logistic_fwd<float>(alpha * s)); |
266 | } |
267 | template <typename T, typename A, |
268 | typename U = typename utils::remove_reference<T>::type> |
269 | inline U swish_bwd(T dd, T s, A alpha) { |
270 | float v = logistic_fwd<float>(alpha * s); |
271 | return dd * (v + s * alpha * v * (1 - v)); |
272 | } |
273 | |
274 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
275 | inline U exp_fwd(T s) { |
276 | return (U)(::expf((float)s)); |
277 | } |
278 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
279 | inline U exp_bwd(T dd, T s) { |
280 | return (U)(dd * (::expf((float)s))); |
281 | } |
282 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
283 | inline U exp_bwd_use_dst(T dd, T d) { |
284 | return (U)(dd * d); |
285 | } |
286 | |
287 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
288 | inline U gelu_tanh_fwd(T s) { |
289 | const float sqrt_2_over_pi = 0.79788458347320556640625f; |
290 | const float fitting_const = 0.044715f; |
291 | float v = tanh_fwd(sqrt_2_over_pi * s * (1 + fitting_const * s * s)); |
292 | return (U)(0.5 * s * (1. + v)); |
293 | } |
294 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
295 | inline U gelu_tanh_bwd(T dd, T s) { |
296 | const float sqrt_2_over_pi = 0.79788458347320556640625f; |
297 | const float fitting_const = 0.044715f; |
298 | float g = s * sqrt_2_over_pi * (1 + fitting_const * s * s); |
299 | float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s); |
300 | float v = tanh_fwd(g); |
301 | return (U)(dd * 0.5 * (1. + v) * (1. + s * (1 - v) * dg)); |
302 | } |
303 | |
304 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
305 | inline U log_fwd(T s) { |
306 | return (U)(::logf((float)s)); |
307 | } |
308 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
309 | inline U log_bwd(T dd, T s) { |
310 | return (U)(dd * (1.f / (float)s)); |
311 | } |
312 | |
313 | template <typename T, typename A, |
314 | typename U = typename utils::remove_reference<T>::type> |
315 | inline U clip_fwd(T s, A alpha, A beta) { |
316 | s = s > alpha ? s : (U)alpha; |
317 | return s > beta ? (U)beta : s; |
318 | } |
319 | template <typename T, typename A, |
320 | typename U = typename utils::remove_reference<T>::type> |
321 | inline U clip_bwd(T dd, T s, A alpha, A beta) { |
322 | return dd * (alpha < s && s <= beta ? 1 : 0); |
323 | } |
324 | |
325 | template <typename T, typename A, |
326 | typename U = typename utils::remove_reference<T>::type> |
327 | inline U clip_v2_fwd(T s, A alpha, A beta) { |
328 | s = s > alpha ? s : (U)alpha; |
329 | return s < beta ? s : (U)beta; |
330 | } |
331 | template <typename T, typename A, |
332 | typename U = typename utils::remove_reference<T>::type> |
333 | inline U clip_v2_bwd(T dd, T s, A alpha, A beta) { |
334 | return dd * (alpha < s && s < beta ? 1 : 0); |
335 | } |
336 | template <typename T, typename A, |
337 | typename U = typename utils::remove_reference<T>::type> |
338 | inline U clip_v2_bwd_use_dst(T dd, T d, A alpha, A beta) { |
339 | return clip_v2_bwd(dd, d, alpha, beta); |
340 | } |
341 | |
342 | template <typename T, typename A, |
343 | typename U = typename utils::remove_reference<T>::type> |
344 | inline U pow_fwd(T s, A alpha, A beta) { |
345 | return (U)(alpha * ::powf((float)s, beta)); |
346 | } |
347 | template <typename T, typename A, |
348 | typename U = typename utils::remove_reference<T>::type> |
349 | inline U pow_bwd(T dd, T s, A alpha, A beta) { |
350 | if (beta == 0) return 0; |
351 | |
352 | float v = pow_fwd(s, alpha * beta, beta - 1); |
353 | return (U)(dd * v); |
354 | } |
355 | |
356 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
357 | inline U gelu_erf_fwd(T s) { |
358 | const float sqrt_2_over_2 = 0.707106769084930419921875f; |
359 | float v = s * sqrt_2_over_2; |
360 | return (U)(0.5f * s * (1.f + ::erff(v))); |
361 | } |
362 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
363 | inline U gelu_erf_bwd(T dd, T s) { |
364 | const float two_over_sqrt_pi = 1.12837922573089599609375f; |
365 | const float sqrt_2_over_2 = 0.707106769084930419921875f; |
366 | float v = s * sqrt_2_over_2; |
367 | return (U)(dd * 0.5f |
368 | * (1.f + ::erff(v) + v * two_over_sqrt_pi * ::expf(-v * v))); |
369 | } |
370 | |
371 | template <typename T, typename A, |
372 | typename U = typename utils::remove_reference<T>::type> |
373 | inline U hardsigmoid_fwd(T s, A alpha, A beta) { |
374 | float v = alpha * s + beta; |
375 | return v <= 0.f ? 0.f : v >= 1.f ? 1.f : v; |
376 | } |
377 | template <typename T, typename A, |
378 | typename U = typename utils::remove_reference<T>::type> |
379 | inline U hardsigmoid_bwd(T dd, T s, A alpha, A beta) { |
380 | float v = alpha * s + beta; |
381 | return v <= 0.f ? 0.f : v >= 1.f ? 0.f : dd * alpha; |
382 | } |
383 | |
384 | template <typename T, typename A, |
385 | typename U = typename utils::remove_reference<T>::type> |
386 | inline U hardswish_fwd(T s, A alpha, A beta) { |
387 | return (U)(s * hardsigmoid_fwd(s, alpha, beta)); |
388 | } |
389 | template <typename T, typename A, |
390 | typename U = typename utils::remove_reference<T>::type> |
391 | inline U hardswish_bwd(T dd, T s, A alpha, A beta) { |
392 | float v = alpha * s + beta; |
393 | float w = 2.f * alpha * s + beta; |
394 | return v <= 0.f ? 0.f : v >= 1.f ? dd : dd * w; |
395 | } |
396 | |
397 | inline bool is_eltwise_ok( |
398 | data_type_t src_dt, alg_kind_t alg, float alpha, float beta) { |
399 | using namespace alg_kind; |
400 | using namespace utils; |
401 | |
402 | const bool eltwise_use_src |
403 | = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, |
404 | eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, |
405 | eltwise_soft_relu, eltwise_mish, eltwise_logistic, |
406 | eltwise_exp, eltwise_gelu_tanh, eltwise_hardsigmoid, |
407 | eltwise_hardswish, eltwise_swish, eltwise_log, |
408 | eltwise_clip, eltwise_clip_v2, eltwise_pow, |
409 | eltwise_gelu_erf, eltwise_round) |
410 | && IMPLICATION( |
411 | one_of(alg, eltwise_clip, eltwise_clip_v2), beta >= alpha) |
412 | && IMPLICATION(alg == eltwise_round, src_dt == dnnl_f32) |
413 | && IMPLICATION(one_of(src_dt, dnnl_s32, dnnl_s8, dnnl_u8), |
414 | one_of(alg, eltwise_relu, eltwise_linear)); |
415 | |
416 | const bool eltwise_use_dst |
417 | = one_of(alg, eltwise_relu_use_dst_for_bwd, |
418 | eltwise_tanh_use_dst_for_bwd, eltwise_elu_use_dst_for_bwd, |
419 | eltwise_sqrt_use_dst_for_bwd, |
420 | eltwise_logistic_use_dst_for_bwd, |
421 | eltwise_exp_use_dst_for_bwd, |
422 | eltwise_clip_v2_use_dst_for_bwd) |
423 | && IMPLICATION(one_of(alg, eltwise_relu_use_dst_for_bwd, |
424 | eltwise_elu_use_dst_for_bwd), |
425 | alpha >= 0) |
426 | && IMPLICATION( |
427 | alg == eltwise_clip_v2_use_dst_for_bwd, beta >= alpha); |
428 | |
429 | return eltwise_use_src || eltwise_use_dst; |
430 | } |
431 | |
432 | } // namespace math |
433 | } // namespace impl |
434 | } // namespace dnnl |
435 | |
436 | #endif |
437 | |