1#pragma once
2
3#ifdef __HIPCC__
4#include <hip/hip_runtime.h>
5#endif
6
7#include <c10/macros/Macros.h>
8#include <c10/util/BFloat16.h>
9#include <c10/util/Half.h>
10#include <c10/util/complex.h>
11
12#include <cmath>
13#include <type_traits>
14
15namespace at {
16
// std::isnan is inefficient on integral types: it (uselessly) converts the
// value to floating point before performing the test. at::_isnan avoids that
// by returning false directly for integral inputs.
20
21template <
22 typename T,
23 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
24inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
25 return false;
26}
27
28template <
29 typename T,
30 typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
31inline C10_HOST_DEVICE bool _isnan(T val) {
32#if defined(__CUDACC__) || defined(__HIPCC__)
33 return ::isnan(val);
34#else
35 return std::isnan(val);
36#endif
37}
38
39template <
40 typename T,
41 typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
42inline bool _isnan(T val) {
43 return std::isnan(val.real()) || std::isnan(val.imag());
44}
45
46template <
47 typename T,
48 typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0>
49inline C10_HOST_DEVICE bool _isnan(T val) {
50 return at::_isnan(static_cast<float>(val));
51}
52
53template <
54 typename T,
55 typename std::enable_if<std::is_same<T, at::BFloat16>::value, int>::type =
56 0>
57inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
58 return at::_isnan(static_cast<float>(val));
59}
60
61inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
62 return at::_isnan(static_cast<float>(val));
63}
64
// std::isinf is inefficient on integral types: it (uselessly) converts the
// value to floating point before performing the test. at::_isinf avoids that
// by returning false directly for integral inputs.
68
69template <
70 typename T,
71 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
72inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
73 return false;
74}
75
76template <
77 typename T,
78 typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
79inline C10_HOST_DEVICE bool _isinf(T val) {
80#if defined(__CUDACC__) || defined(__HIPCC__)
81 return ::isinf(val);
82#else
83 return std::isinf(val);
84#endif
85}
86
87inline C10_HOST_DEVICE bool _isinf(at::Half val) {
88 return at::_isinf(static_cast<float>(val));
89}
90
91inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
92 return at::_isinf(static_cast<float>(val));
93}
94
95template <typename T>
96C10_HOST_DEVICE inline T exp(T x) {
97 static_assert(
98 !std::is_same<T, double>::value,
99 "this template must be used with float or less precise type");
100#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
101 // use __expf fast approximation for peak bandwidth
102 return __expf(x);
103#else
104 return ::exp(x);
105#endif
106}
107
108template <>
109C10_HOST_DEVICE inline double exp<double>(double x) {
110 return ::exp(x);
111}
112
113template <typename T>
114C10_HOST_DEVICE inline T log(T x) {
115 static_assert(
116 !std::is_same<T, double>::value,
117 "this template must be used with float or less precise type");
118#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
119 // use __logf fast approximation for peak bandwidth
120 return __logf(x);
121#else
122 return ::log(x);
123#endif
124}
125
126template <>
127C10_HOST_DEVICE inline double log<double>(double x) {
128 return ::log(x);
129}
130
131template <typename T>
132C10_HOST_DEVICE inline T log1p(T x) {
133 static_assert(
134 !std::is_same<T, double>::value,
135 "this template must be used with float or less precise type");
136#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
137 // use __logf fast approximation for peak bandwidth
138 // NOTE: There is no __log1pf so unfortunately we lose precision.
139 return __logf(1.0f + x);
140#else
141 return ::log1p(x);
142#endif
143}
144
145template <>
146C10_HOST_DEVICE inline double log1p<double>(double x) {
147 return ::log1p(x);
148}
149
150template <typename T>
151C10_HOST_DEVICE inline T tan(T x) {
152 static_assert(
153 !std::is_same<T, double>::value,
154 "this template must be used with float or less precise type");
155#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
156 // use __tanf fast approximation for peak bandwidth
157 return __tanf(x);
158#else
159 return ::tan(x);
160#endif
161}
162
163template <>
164C10_HOST_DEVICE inline double tan<double>(double x) {
165 return ::tan(x);
166}
167
168} // namespace at
169