1#pragma once
2
3#include <c10/util/BFloat16-inl.h>
4#include <c10/util/math_compat.h>
5
6C10_CLANG_DIAGNOSTIC_PUSH()
7#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
8C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
9#endif
10
11namespace std {
12
13/// Used by vec256<c10::BFloat16>::map
14inline c10::BFloat16 acos(c10::BFloat16 a) {
15 return std::acos(float(a));
16}
17inline c10::BFloat16 asin(c10::BFloat16 a) {
18 return std::asin(float(a));
19}
20inline c10::BFloat16 atan(c10::BFloat16 a) {
21 return std::atan(float(a));
22}
23inline c10::BFloat16 erf(c10::BFloat16 a) {
24 return std::erf(float(a));
25}
26inline c10::BFloat16 erfc(c10::BFloat16 a) {
27 return std::erfc(float(a));
28}
29inline c10::BFloat16 exp(c10::BFloat16 a) {
30 return std::exp(float(a));
31}
32inline c10::BFloat16 expm1(c10::BFloat16 a) {
33 return std::expm1(float(a));
34}
35inline c10::BFloat16 log(c10::BFloat16 a) {
36 return std::log(float(a));
37}
38inline c10::BFloat16 log10(c10::BFloat16 a) {
39 return std::log10(float(a));
40}
41inline c10::BFloat16 log1p(c10::BFloat16 a) {
42 return std::log1p(float(a));
43}
44inline c10::BFloat16 log2(c10::BFloat16 a) {
45 return std::log2(float(a));
46}
47inline c10::BFloat16 ceil(c10::BFloat16 a) {
48 return std::ceil(float(a));
49}
50inline c10::BFloat16 cos(c10::BFloat16 a) {
51 return std::cos(float(a));
52}
53inline c10::BFloat16 floor(c10::BFloat16 a) {
54 return std::floor(float(a));
55}
56inline c10::BFloat16 nearbyint(c10::BFloat16 a) {
57 return std::nearbyint(float(a));
58}
59inline c10::BFloat16 sin(c10::BFloat16 a) {
60 return std::sin(float(a));
61}
62inline c10::BFloat16 tan(c10::BFloat16 a) {
63 return std::tan(float(a));
64}
65inline c10::BFloat16 sinh(c10::BFloat16 a) {
66 return std::sinh(float(a));
67}
68inline c10::BFloat16 cosh(c10::BFloat16 a) {
69 return std::cosh(float(a));
70}
71inline c10::BFloat16 tanh(c10::BFloat16 a) {
72 return std::tanh(float(a));
73}
74inline c10::BFloat16 trunc(c10::BFloat16 a) {
75 return std::trunc(float(a));
76}
77inline c10::BFloat16 lgamma(c10::BFloat16 a) {
78 return std::lgamma(float(a));
79}
80inline c10::BFloat16 sqrt(c10::BFloat16 a) {
81 return std::sqrt(float(a));
82}
83inline c10::BFloat16 rsqrt(c10::BFloat16 a) {
84 return 1.0 / std::sqrt(float(a));
85}
86inline c10::BFloat16 abs(c10::BFloat16 a) {
87 return std::abs(float(a));
88}
89#if defined(_MSC_VER) && defined(__CUDACC__)
90inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
91 return std::pow(float(a), float(b));
92}
93#else
94inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
95 return std::pow(float(a), b);
96}
97#endif
98inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) {
99 return std::pow(float(a), float(b));
100}
101inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) {
102 return std::fmod(float(a), float(b));
103}
104
105/*
106 The following function is inspired from the implementation in `musl`
107 Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
108 ----------------------------------------------------------------------
109 Copyright © 2005-2020 Rich Felker, et al.
110
111 Permission is hereby granted, free of charge, to any person obtaining
112 a copy of this software and associated documentation files (the
113 "Software"), to deal in the Software without restriction, including
114 without limitation the rights to use, copy, modify, merge, publish,
115 distribute, sublicense, and/or sell copies of the Software, and to
116 permit persons to whom the Software is furnished to do so, subject to
117 the following conditions:
118
119 The above copyright notice and this permission notice shall be
120 included in all copies or substantial portions of the Software.
121
122 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
123 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
124 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
125 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
126 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
127 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
128 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
129 ----------------------------------------------------------------------
130 */
131C10_HOST_DEVICE inline c10::BFloat16 nextafter(
132 c10::BFloat16 from,
133 c10::BFloat16 to) {
134 // Reference:
135 // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
136 using int_repr_t = uint16_t;
137 using float_t = c10::BFloat16;
138 constexpr uint8_t bits = 16;
139 union {
140 float_t f;
141 int_repr_t i;
142 } ufrom = {from}, uto = {to};
143
144 // get a mask to get the sign bit i.e. MSB
145 int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
146
147 // short-circuit: if either is NaN, return NaN
148 if (from != from || to != to) {
149 return from + to;
150 }
151
152 // short-circuit: if they are exactly the same.
153 if (ufrom.i == uto.i) {
154 return from;
155 }
156
157 // mask the sign-bit to zero i.e. positive
158 // equivalent to abs(x)
159 int_repr_t abs_from = ufrom.i & ~sign_mask;
160 int_repr_t abs_to = uto.i & ~sign_mask;
161 if (abs_from == 0) {
162 // if both are zero but with different sign,
163 // preserve the sign of `to`.
164 if (abs_to == 0) {
165 return to;
166 }
167 // smallest subnormal with sign of `to`.
168 ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
169 return ufrom.f;
170 }
171
172 // if abs(from) > abs(to) or sign(from) != sign(to)
173 if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
174 ufrom.i--;
175 } else {
176 ufrom.i++;
177 }
178
179 return ufrom.f;
180}
181
182} // namespace std
183
184C10_CLANG_DIAGNOSTIC_POP()
185