1 | #pragma once |
2 | |
3 | #include <c10/macros/Macros.h> |
4 | #include <c10/util/BFloat16.h> |
5 | #include <c10/util/Half.h> |
6 | |
7 | C10_CLANG_DIAGNOSTIC_PUSH() |
8 | #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") |
9 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion" ) |
10 | #endif |
11 | |
12 | namespace c10 { |
13 | // TODO: Replace me with inline constexpr variable when C++17 becomes available |
14 | namespace detail { |
15 | template <typename T> |
16 | C10_HOST_DEVICE inline constexpr T e() { |
17 | return static_cast<T>(2.718281828459045235360287471352662); |
18 | } |
19 | |
20 | template <typename T> |
21 | C10_HOST_DEVICE inline constexpr T euler() { |
22 | return static_cast<T>(0.577215664901532860606512090082402); |
23 | } |
24 | |
25 | template <typename T> |
26 | C10_HOST_DEVICE inline constexpr T frac_1_pi() { |
27 | return static_cast<T>(0.318309886183790671537767526745028); |
28 | } |
29 | |
30 | template <typename T> |
31 | C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() { |
32 | return static_cast<T>(0.564189583547756286948079451560772); |
33 | } |
34 | |
35 | template <typename T> |
36 | C10_HOST_DEVICE inline constexpr T frac_sqrt_2() { |
37 | return static_cast<T>(0.707106781186547524400844362104849); |
38 | } |
39 | |
40 | template <typename T> |
41 | C10_HOST_DEVICE inline constexpr T frac_sqrt_3() { |
42 | return static_cast<T>(0.577350269189625764509148780501957); |
43 | } |
44 | |
45 | template <typename T> |
46 | C10_HOST_DEVICE inline constexpr T golden_ratio() { |
47 | return static_cast<T>(1.618033988749894848204586834365638); |
48 | } |
49 | |
50 | template <typename T> |
51 | C10_HOST_DEVICE inline constexpr T ln_10() { |
52 | return static_cast<T>(2.302585092994045684017991454684364); |
53 | } |
54 | |
55 | template <typename T> |
56 | C10_HOST_DEVICE inline constexpr T ln_2() { |
57 | return static_cast<T>(0.693147180559945309417232121458176); |
58 | } |
59 | |
60 | template <typename T> |
61 | C10_HOST_DEVICE inline constexpr T log_10_e() { |
62 | return static_cast<T>(0.434294481903251827651128918916605); |
63 | } |
64 | |
65 | template <typename T> |
66 | C10_HOST_DEVICE inline constexpr T log_2_e() { |
67 | return static_cast<T>(1.442695040888963407359924681001892); |
68 | } |
69 | |
70 | template <typename T> |
71 | C10_HOST_DEVICE inline constexpr T pi() { |
72 | return static_cast<T>(3.141592653589793238462643383279502); |
73 | } |
74 | |
75 | template <typename T> |
76 | C10_HOST_DEVICE inline constexpr T sqrt_2() { |
77 | return static_cast<T>(1.414213562373095048801688724209698); |
78 | } |
79 | |
80 | template <typename T> |
81 | C10_HOST_DEVICE inline constexpr T sqrt_3() { |
82 | return static_cast<T>(1.732050807568877293527446341505872); |
83 | } |
84 | |
85 | template <> |
86 | C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() { |
87 | // According to |
88 | // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values |
89 | // pi is encoded as 4049 |
90 | return BFloat16(0x4049, BFloat16::from_bits()); |
91 | } |
92 | |
93 | template <> |
94 | C10_HOST_DEVICE inline constexpr Half pi<Half>() { |
95 | return Half(0x4248, Half::from_bits()); |
96 | } |
97 | } // namespace detail |
98 | |
99 | template <typename T> |
100 | constexpr T e = c10::detail::e<T>(); |
101 | |
102 | template <typename T> |
103 | constexpr T euler = c10::detail::euler<T>(); |
104 | |
105 | template <typename T> |
106 | constexpr T frac_1_pi = c10::detail::frac_1_pi<T>(); |
107 | |
108 | template <typename T> |
109 | constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>(); |
110 | |
111 | template <typename T> |
112 | constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>(); |
113 | |
114 | template <typename T> |
115 | constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>(); |
116 | |
117 | template <typename T> |
118 | constexpr T golden_ratio = c10::detail::golden_ratio<T>(); |
119 | |
120 | template <typename T> |
121 | constexpr T ln_10 = c10::detail::ln_10<T>(); |
122 | |
123 | template <typename T> |
124 | constexpr T ln_2 = c10::detail::ln_2<T>(); |
125 | |
126 | template <typename T> |
127 | constexpr T log_10_e = c10::detail::log_10_e<T>(); |
128 | |
129 | template <typename T> |
130 | constexpr T log_2_e = c10::detail::log_2_e<T>(); |
131 | |
132 | template <typename T> |
133 | constexpr T pi = c10::detail::pi<T>(); |
134 | |
135 | template <typename T> |
136 | constexpr T sqrt_2 = c10::detail::sqrt_2<T>(); |
137 | |
138 | template <typename T> |
139 | constexpr T sqrt_3 = c10::detail::sqrt_3<T>(); |
140 | } // namespace c10 |
141 | |
142 | C10_CLANG_DIAGNOSTIC_POP() |
143 | |