1 | #pragma once |
2 | |
3 | #include <c10/util/BFloat16-inl.h> |
4 | #include <c10/util/math_compat.h> |
5 | |
6 | C10_CLANG_DIAGNOSTIC_PUSH() |
7 | #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") |
8 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion" ) |
9 | #endif |
10 | |
11 | namespace std { |
12 | |
13 | /// Used by vec256<c10::BFloat16>::map |
14 | inline c10::BFloat16 acos(c10::BFloat16 a) { |
15 | return std::acos(float(a)); |
16 | } |
17 | inline c10::BFloat16 asin(c10::BFloat16 a) { |
18 | return std::asin(float(a)); |
19 | } |
20 | inline c10::BFloat16 atan(c10::BFloat16 a) { |
21 | return std::atan(float(a)); |
22 | } |
23 | inline c10::BFloat16 erf(c10::BFloat16 a) { |
24 | return std::erf(float(a)); |
25 | } |
26 | inline c10::BFloat16 erfc(c10::BFloat16 a) { |
27 | return std::erfc(float(a)); |
28 | } |
29 | inline c10::BFloat16 exp(c10::BFloat16 a) { |
30 | return std::exp(float(a)); |
31 | } |
32 | inline c10::BFloat16 expm1(c10::BFloat16 a) { |
33 | return std::expm1(float(a)); |
34 | } |
35 | inline c10::BFloat16 log(c10::BFloat16 a) { |
36 | return std::log(float(a)); |
37 | } |
38 | inline c10::BFloat16 log10(c10::BFloat16 a) { |
39 | return std::log10(float(a)); |
40 | } |
41 | inline c10::BFloat16 log1p(c10::BFloat16 a) { |
42 | return std::log1p(float(a)); |
43 | } |
44 | inline c10::BFloat16 log2(c10::BFloat16 a) { |
45 | return std::log2(float(a)); |
46 | } |
47 | inline c10::BFloat16 ceil(c10::BFloat16 a) { |
48 | return std::ceil(float(a)); |
49 | } |
50 | inline c10::BFloat16 cos(c10::BFloat16 a) { |
51 | return std::cos(float(a)); |
52 | } |
53 | inline c10::BFloat16 floor(c10::BFloat16 a) { |
54 | return std::floor(float(a)); |
55 | } |
56 | inline c10::BFloat16 nearbyint(c10::BFloat16 a) { |
57 | return std::nearbyint(float(a)); |
58 | } |
59 | inline c10::BFloat16 sin(c10::BFloat16 a) { |
60 | return std::sin(float(a)); |
61 | } |
62 | inline c10::BFloat16 tan(c10::BFloat16 a) { |
63 | return std::tan(float(a)); |
64 | } |
65 | inline c10::BFloat16 sinh(c10::BFloat16 a) { |
66 | return std::sinh(float(a)); |
67 | } |
68 | inline c10::BFloat16 cosh(c10::BFloat16 a) { |
69 | return std::cosh(float(a)); |
70 | } |
71 | inline c10::BFloat16 tanh(c10::BFloat16 a) { |
72 | return std::tanh(float(a)); |
73 | } |
74 | inline c10::BFloat16 trunc(c10::BFloat16 a) { |
75 | return std::trunc(float(a)); |
76 | } |
77 | inline c10::BFloat16 lgamma(c10::BFloat16 a) { |
78 | return std::lgamma(float(a)); |
79 | } |
80 | inline c10::BFloat16 sqrt(c10::BFloat16 a) { |
81 | return std::sqrt(float(a)); |
82 | } |
83 | inline c10::BFloat16 rsqrt(c10::BFloat16 a) { |
84 | return 1.0 / std::sqrt(float(a)); |
85 | } |
86 | inline c10::BFloat16 abs(c10::BFloat16 a) { |
87 | return std::abs(float(a)); |
88 | } |
89 | #if defined(_MSC_VER) && defined(__CUDACC__) |
90 | inline c10::BFloat16 pow(c10::BFloat16 a, double b) { |
91 | return std::pow(float(a), float(b)); |
92 | } |
93 | #else |
94 | inline c10::BFloat16 pow(c10::BFloat16 a, double b) { |
95 | return std::pow(float(a), b); |
96 | } |
97 | #endif |
98 | inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) { |
99 | return std::pow(float(a), float(b)); |
100 | } |
101 | inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) { |
102 | return std::fmod(float(a), float(b)); |
103 | } |
104 | |
105 | /* |
The following function is inspired by the implementation in `musl`
107 | Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT |
108 | ---------------------------------------------------------------------- |
109 | Copyright © 2005-2020 Rich Felker, et al. |
110 | |
111 | Permission is hereby granted, free of charge, to any person obtaining |
112 | a copy of this software and associated documentation files (the |
113 | "Software"), to deal in the Software without restriction, including |
114 | without limitation the rights to use, copy, modify, merge, publish, |
115 | distribute, sublicense, and/or sell copies of the Software, and to |
116 | permit persons to whom the Software is furnished to do so, subject to |
117 | the following conditions: |
118 | |
119 | The above copyright notice and this permission notice shall be |
120 | included in all copies or substantial portions of the Software. |
121 | |
122 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
123 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
124 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
125 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
126 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
127 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
128 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
129 | ---------------------------------------------------------------------- |
130 | */ |
131 | C10_HOST_DEVICE inline c10::BFloat16 nextafter( |
132 | c10::BFloat16 from, |
133 | c10::BFloat16 to) { |
134 | // Reference: |
135 | // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c |
136 | using int_repr_t = uint16_t; |
137 | using float_t = c10::BFloat16; |
138 | constexpr uint8_t bits = 16; |
139 | union { |
140 | float_t f; |
141 | int_repr_t i; |
142 | } ufrom = {from}, uto = {to}; |
143 | |
144 | // get a mask to get the sign bit i.e. MSB |
145 | int_repr_t sign_mask = int_repr_t{1} << (bits - 1); |
146 | |
147 | // short-circuit: if either is NaN, return NaN |
148 | if (from != from || to != to) { |
149 | return from + to; |
150 | } |
151 | |
152 | // short-circuit: if they are exactly the same. |
153 | if (ufrom.i == uto.i) { |
154 | return from; |
155 | } |
156 | |
157 | // mask the sign-bit to zero i.e. positive |
158 | // equivalent to abs(x) |
159 | int_repr_t abs_from = ufrom.i & ~sign_mask; |
160 | int_repr_t abs_to = uto.i & ~sign_mask; |
161 | if (abs_from == 0) { |
162 | // if both are zero but with different sign, |
163 | // preserve the sign of `to`. |
164 | if (abs_to == 0) { |
165 | return to; |
166 | } |
167 | // smallest subnormal with sign of `to`. |
168 | ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; |
169 | return ufrom.f; |
170 | } |
171 | |
172 | // if abs(from) > abs(to) or sign(from) != sign(to) |
173 | if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { |
174 | ufrom.i--; |
175 | } else { |
176 | ufrom.i++; |
177 | } |
178 | |
179 | return ufrom.f; |
180 | } |
181 | |
182 | } // namespace std |
183 | |
184 | C10_CLANG_DIAGNOSTIC_POP() |
185 | |