1 | #ifndef THP_UTILS_H |
2 | #define THP_UTILS_H |
3 | |
#include <ATen/ATen.h>
#include <torch/csrc/THConcat.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <cmath>
#include <string>
#include <type_traits>
#include <vector>
12 | |
13 | #ifdef USE_CUDA |
14 | #include <c10/cuda/CUDAStream.h> |
15 | #endif |
16 | |
// Builds a per-scalar-type helper name by pasting THP, the `Real` token, Utils_
// and NAME together via TH_CONCAT_4 (`Real` is expected to be defined by the
// including translation unit).
#define THPUtils_(NAME) TH_CONCAT_4(THP, Real, Utils_, NAME)

// The `tp_name` string stored in the Python type object of `obj`.
#define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name)

// Branch-prediction hint: THP_EXPECT(x, y) evaluates to `x`, telling compilers
// that understand __builtin_expect that `x` is most likely equal to `y`.
// Elsewhere it degrades to the plain expression.
#if !defined(__GNUC__) && !defined(__ICL) && !defined(__clang__)
#define THP_EXPECT(x, y) (x)
#else
#define THP_EXPECT(x, y) (__builtin_expect((x), (y)))
#endif
26 | |
// Scalar helpers for Python float/int objects.
//
// THPUtils_checkReal_FLOAT accepts Python floats and ints.
// THPUtils_unpackReal_FLOAT yields a double, or throws std::runtime_error for
// any other object.
//
// Ints go through PyLong_AsDouble rather than PyLong_AsLongLong. The result is
// a double either way, and PyLong_AsLongLong fails (returning -1 with a pending
// OverflowError) for ints outside the long long range, a failure this macro
// never checked for.
//
// The macro evaluates `object` several times, so pass a plain name, not an
// expression with side effects.
// NOTE(review): failures reported by PyFloat_AsDouble / PyLong_AsDouble
// (-1.0 plus a pending Python exception) are still not checked here.
#define THPUtils_checkReal_FLOAT(object) \
  (PyFloat_Check(object) || PyLong_Check(object))

#define THPUtils_unpackReal_FLOAT(object)                          \
  (PyFloat_Check(object)                                           \
       ? PyFloat_AsDouble(object)                                  \
       : PyLong_Check(object)                                      \
           ? PyLong_AsDouble(object)                               \
           : (throw std::runtime_error("Could not parse real"), 0.0))
35 | |
// Scalar helpers for Python int objects.
//
// THPUtils_checkReal_INT accepts anything PyLong_Check accepts.
// THPUtils_unpackReal_INT yields a long long, or throws std::runtime_error for
// non-int objects.
//
// NOTE(review): PyLong_AsLongLong returns -1 and sets OverflowError for
// out-of-range ints; that result is passed through unchecked.
#define THPUtils_checkReal_INT(object) (PyLong_Check(object))

#define THPUtils_unpackReal_INT(object)                          \
  (!PyLong_Check(object)                                         \
       ? (throw std::runtime_error("Could not parse real"), 0LL) \
       : PyLong_AsLongLong(object))
42 | |
// Unpacks a Python bool into a C++ bool, throwing std::runtime_error for any
// non-bool object.
//
// Previously this yielded the PyObject* itself. Converting that pointer to
// bool is true for both Py_True and Py_False, and an (int64_t) cast of it
// produces an address. Comparing against the Py_True singleton gives the
// actual truth value.
#define THPUtils_unpackReal_BOOL(object)                        \
  (PyBool_Check(object)                                         \
       ? ((object) == Py_True)                                  \
       : (throw std::runtime_error("Could not parse real"), false))
47 | |
// Unpacks a Python complex, float or int into c10::complex<double>, throwing
// std::runtime_error for any other object. Floats and ints become a value with
// a zero imaginary part.
//
// Ints go through PyLong_AsDouble rather than PyLong_AsLongLong, because the
// latter fails (returning -1 with a pending OverflowError) for ints outside the
// long long range, and that failure was never checked.
//
// `object` is evaluated several times, so pass a plain name.
// NOTE(review): errors reported by the Py*AsDouble calls (-1.0 plus a pending
// Python exception) are still not checked here.
#define THPUtils_unpackReal_COMPLEX(object)                                  \
  (PyComplex_Check(object)                                                   \
       ? c10::complex<double>(                                               \
             PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object)) \
       : PyFloat_Check(object)                                               \
           ? c10::complex<double>(PyFloat_AsDouble(object), 0.0)             \
           : PyLong_Check(object)                                            \
               ? c10::complex<double>(PyLong_AsDouble(object), 0.0)          \
               : (throw std::runtime_error("Could not parse real"),          \
                  c10::complex<double>(0.0, 0.0)))
58 | |
// True when `object` is a Python bool.
#define THPUtils_checkReal_BOOL(object) (PyBool_Check(object))

// True when `object` is a Python complex, float or int. The whole disjunction
// is parenthesized: without that, `!THPUtils_checkReal_COMPLEX(x)` negated only
// the first term (`!PyComplex_Check(x) || ...`), so it was true for any
// non-complex input.
#define THPUtils_checkReal_COMPLEX(object)                                    \
  (PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object) || \
   PyInt_Check(object))
64 | |
// Constructors for new Python scalar objects from C++ values.
//
// INT uses PyLong_FromLongLong. The previous PyInt_FromLong takes a C `long`,
// which is only 32 bits on LLP64 platforms (64-bit Windows), so int64 values
// were silently truncated there.
//
// COMPLEX parenthesizes `value` so expression arguments (e.g. `a + b`) apply
// .real()/.imag() to the whole expression. `value` is evaluated twice.
#define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
#define THPUtils_newReal_INT(value) PyLong_FromLongLong(value)

#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)

#define THPUtils_newReal_COMPLEX(value) \
  PyComplex_FromDoubles((value).real(), (value).imag())
72 | |
// Per-scalar-type adapters for floating-point and complex element types, named
// THP<Type>Utils_<op> so they can be selected by token pasting.
//
// Expansions are fully parenthesized and use named casts. A C-style cast such
// as `(at::Half)x` performs exactly this static_cast, so results are unchanged.
// Half and BFloat16 values are narrowed from the double produced by
// THPUtils_unpackReal_FLOAT.
#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPDoubleUtils_unpackReal(object) \
  (static_cast<double>(THPUtils_unpackReal_FLOAT(object)))
#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value)

#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPFloatUtils_unpackReal(object) \
  (static_cast<float>(THPUtils_unpackReal_FLOAT(object)))
#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value)

#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPHalfUtils_unpackReal(object) \
  (static_cast<at::Half>(THPUtils_unpackReal_FLOAT(object)))
#define THPHalfUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)

#define THPComplexDoubleUtils_checkReal(object) \
  THPUtils_checkReal_COMPLEX(object)
#define THPComplexDoubleUtils_unpackReal(object) \
  THPUtils_unpackReal_COMPLEX(object)
#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value)

#define THPComplexFloatUtils_checkReal(object) \
  THPUtils_checkReal_COMPLEX(object)
#define THPComplexFloatUtils_unpackReal(object) \
  (static_cast<c10::complex<float>>(THPUtils_unpackReal_COMPLEX(object)))
#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value)

#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPBFloat16Utils_unpackReal(object) \
  (static_cast<at::BFloat16>(THPUtils_unpackReal_FLOAT(object)))
#define THPBFloat16Utils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)
101 | |
// Per-scalar-type adapters for bool, integer and quantized element types,
// named THP<Type>Utils_<op> for selection by token pasting. Integer-like types
// all go through the *_INT helpers and narrow the unpacked value with a cast.

// bool
#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value)
#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object)
// Deliberately a C-style cast: it must compile whatever type
// THPUtils_unpackReal_BOOL produces, including a pointer, which
// static_cast<int64_t> would reject.
#define THPBoolUtils_unpackAccreal(object) \
  ((int64_t)THPUtils_unpackReal_BOOL(object))
#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value)

// int64_t
#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPLongUtils_unpackReal(object) \
  (static_cast<int64_t>(THPUtils_unpackReal_INT(object)))
#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value)

// int
#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPIntUtils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value)

// short
#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPShortUtils_unpackReal(object) \
  (static_cast<short>(THPUtils_unpackReal_INT(object)))
#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value)

// char
#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPCharUtils_unpackReal(object) \
  (static_cast<char>(THPUtils_unpackReal_INT(object)))
#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value)

// unsigned char
#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPByteUtils_unpackReal(object) \
  (static_cast<unsigned char>(THPUtils_unpackReal_INT(object)))
#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value)

// Quantized types: all are checked and created as Python ints and unpacked as
// a plain C++ int.
#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt8Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value)

#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt8Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value)

#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt32Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value)

#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt4x2Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value)

#define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt2x4Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value)
142 | |
// Placeholder for a PyTypeObject slot that should hold the address of a Python
// object defined in another shared library.
//
// The pattern comes from CPython's Modules/xxsubtype.c (v3.7.0). When building
// a shared library, some compilers reject such cross-library addresses in
// static PyTypeObject initializers. The tagged slot is therefore statically
// null, and the module init function that adds the type to the module must fill
// it in at runtime.
//
// ADDR is purely documentation: the macro discards it without evaluating it.
#define DEFERRED_ADDRESS(ADDR) (nullptr)
153 | |
// Guard macros for Python-facing C++ functions.
//
// THPUtils_assertRet(value, cond, fmt, ...) checks `cond`. If it is false, it
// passes the remaining arguments to THPUtils_setError and makes the enclosing
// function `return value`. THPUtils_assert is the same but returns nullptr.
//
// The body is wrapped in do { ... } while (0) so the macro behaves as a single
// statement. The previous bare `if { ... }` expansion broke when used as the
// body of an unbraced if/else (`if (a) THPUtils_assert(...); else ...`).
// THP_EXPECT marks the failure branch as unlikely.
#define THPUtils_assert(cond, ...) \
  THPUtils_assertRet(nullptr, cond, __VA_ARGS__)
#define THPUtils_assertRet(value, cond, ...) \
  do {                                       \
    if (THP_EXPECT(!(cond), 0)) {            \
      THPUtils_setError(__VA_ARGS__);        \
      return value;                          \
    }                                        \
  } while (0)
161 | TORCH_PYTHON_API void THPUtils_setError(const char* format, ...); |
162 | TORCH_PYTHON_API void THPUtils_invalidArguments( |
163 | PyObject* given_args, |
164 | PyObject* given_kwargs, |
165 | const char* function_name, |
166 | size_t num_options, |
167 | ...); |
168 | |
169 | bool THPUtils_checkIntTuple(PyObject* arg); |
170 | std::vector<int> THPUtils_unpackIntTuple(PyObject* arg); |
171 | |
172 | void THPUtils_addPyMethodDefs( |
173 | std::vector<PyMethodDef>& vector, |
174 | PyMethodDef* methods); |
175 | |
176 | int THPUtils_getCallable(PyObject* arg, PyObject** result); |
177 | |
178 | typedef THPPointer<THPGenerator> THPGeneratorPtr; |
179 | typedef class THPPointer<THPStorage> THPStoragePtr; |
180 | |
181 | std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg); |
182 | PyObject* THPUtils_dispatchStateless( |
183 | PyObject* tensor, |
184 | const char* name, |
185 | PyObject* args, |
186 | PyObject* kwargs); |
187 | |
// Remainder dispatch for arithmetic element types: mod_traits<T>::mod(a, b)
// computes a C-style (truncating) remainder whose sign follows `a`.
// Types that are neither floating point nor integral get the empty primary
// template, so using ::mod on them is a compile error.
template <typename _real, typename = void>
struct mod_traits {};

// Floating-point types use std::fmod. It has float/double/long double
// overloads, so the computation stays in _real's precision. An unqualified
// `fmod` could resolve to C's `fmod(double, double)` and silently round
// long double inputs.
template <typename _real>
struct mod_traits<
    _real,
    typename std::enable_if<std::is_floating_point<_real>::value>::type> {
  static _real mod(_real a, _real b) {
    return std::fmod(a, b);
  }
};

// Integral types use the built-in `%`. Precondition: b != 0, and not the
// (minimum signed value, -1) pair; both are undefined behavior. Small types
// are promoted to int by `%`, hence the cast back to _real.
template <typename _real>
struct mod_traits<
    _real,
    typename std::enable_if<std::is_integral<_real>::value>::type> {
  static _real mod(_real a, _real b) {
    return static_cast<_real>(a % b);
  }
};
208 | |
// Back-compat warning switches. The definitions are out of line; the
// descriptions below are inferred from the names, so confirm them there.

// Setter/getter pair for the broadcasting back-compat warning flag.
void setBackCompatBroadcastWarn(bool warn);
bool getBackCompatBroadcastWarn();

// Setter/getter pair for the keepdim back-compat warning flag.
void setBackCompatKeepdimWarn(bool warn);
bool getBackCompatKeepdimWarn();

// NOTE(review): takes a non-const char*, so callers holding a string literal
// need a cast. The signature is left unchanged because it must match the
// out-of-line definition.
bool maybeThrowBackCompatKeepdimWarn(char* func);
215 | |
// NB: unlike its neighbours, this is defined in torch/csrc/cuda/utils.cpp, and
// it is only declared in CUDA-enabled builds.
#if defined(USE_CUDA)
std::vector<c10::optional<at::cuda::CUDAStream>>
THPUtils_PySequence_to_CUDAStreamList(PyObject* obj);
#endif
221 | |
222 | void storage_copy(at::Storage dst, at::Storage src, bool non_blocking = false); |
223 | void storage_fill(at::Storage self, uint8_t value); |
224 | void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value); |
225 | uint8_t storage_get(at::Storage self, ptrdiff_t idx); |
226 | |
227 | #endif |
228 | |