1#ifndef THP_UTILS_H
2#define THP_UTILS_H
3
#include <ATen/ATen.h>
#include <torch/csrc/THConcat.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <cmath>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
16
// Builds the per-scalar-type helper name THP<Real>Utils_<NAME> (for example
// THPFloatUtils_checkReal). `Real` is not defined in this header; the including
// code must define it before using this macro.
#define THPUtils_(NAME) TH_CONCAT_4(THP, Real, Utils_, NAME)

// The Python type name (tp_name) of `obj`.
#define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name)

// Branch-prediction hint: THP_EXPECT(x, y) evaluates to `x`. On compilers that
// provide __builtin_expect it also tells the optimizer that `x` is expected to
// equal `y`; everywhere else it degrades to a plain `(x)`.
#if !defined(__GNUC__) && !defined(__ICL) && !defined(__clang__)
#define THP_EXPECT(x, y) (x)
#else
#define THP_EXPECT(x, y) (__builtin_expect((x), (y)))
#endif
26
// Python number <-> C scalar conversion primitives, one family per scalar
// kind (FLOAT, INT, BOOL, COMPLEX):
//   THPUtils_checkReal_*  - type test on a PyObject*.
//   THPUtils_unpackReal_* - PyObject* -> C value; throws
//                           std::runtime_error("Could not parse real") when the
//                           object has an unsupported Python type.
//   THPUtils_newReal_*    - C value -> new Python object.
// NOTE(review): the PyLong_As* / PyFloat_As* / PyComplex_*AsDouble calls can
// still fail on a value of the right Python type (e.g. an int outside the
// int64 range in THPUtils_unpackReal_INT). CPython then returns -1 and leaves
// a Python exception set; these macros do not check PyErr_Occurred().

#define THPUtils_checkReal_FLOAT(object) \
  (PyFloat_Check(object) || PyLong_Check(object))

// Ints are converted directly with PyLong_AsDouble rather than through a
// 64-bit integer, so ints beyond the int64 range still yield the nearest
// double instead of -1 with a pending OverflowError.
#define THPUtils_unpackReal_FLOAT(object) \
  (PyFloat_Check(object)                  \
       ? PyFloat_AsDouble(object)         \
       : PyLong_Check(object)             \
       ? PyLong_AsDouble(object)          \
       : (throw std::runtime_error("Could not parse real"), 0))

#define THPUtils_checkReal_INT(object) (PyLong_Check(object))

#define THPUtils_unpackReal_INT(object) \
  (PyLong_Check(object)                 \
       ? PyLong_AsLongLong(object)      \
       : (throw std::runtime_error("Could not parse real"), 0))

// NOTE(review): despite the name, this yields the PyObject* itself (Py_True or
// Py_False), not a C bool; compare it with Py_True before treating it as a
// number.
#define THPUtils_unpackReal_BOOL(object) \
  (PyBool_Check(object)                  \
       ? (object)                        \
       : (throw std::runtime_error("Could not parse real"), Py_False))

// Ints use PyLong_AsDouble for the same reason as THPUtils_unpackReal_FLOAT.
#define THPUtils_unpackReal_COMPLEX(object)                                  \
  (PyComplex_Check(object)                                                   \
       ? (c10::complex<double>(                                              \
             PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) \
       : PyFloat_Check(object)                                               \
       ? (c10::complex<double>(PyFloat_AsDouble(object), 0))                 \
       : PyLong_Check(object)                                                \
       ? (c10::complex<double>(PyLong_AsDouble(object), 0))                  \
       : (throw std::runtime_error("Could not parse real"),                  \
          c10::complex<double>(0, 0)))

#define THPUtils_checkReal_BOOL(object) (PyBool_Check(object))

// The whole disjunction is parenthesized so that e.g.
// `!THPUtils_checkReal_COMPLEX(x)` negates the entire test, not only its first
// operand. The Python 2 name PyInt_Check is not needed: on Python 3 every int
// satisfies PyLong_Check.
#define THPUtils_checkReal_COMPLEX(object) \
  (PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object))

#define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
// PyLong_FromLongLong rather than PyLong_FromLong: `long` is 32 bits on LLP64
// platforms such as Windows, which would truncate 64-bit values.
#define THPUtils_newReal_INT(value) PyLong_FromLongLong(value)

#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)

// `value` is parenthesized so that an expression argument (e.g. `a + b`) is
// evaluated as a whole before .real()/.imag() are applied.
#define THPUtils_newReal_COMPLEX(value) \
  PyComplex_FromDoubles((value).real(), (value).imag())
72
// Per-scalar-type aliases, named THP<Real>Utils_<name> to match what
// THPUtils_(name) expands to. Each unpackReal narrows the result of the generic
// THPUtils_unpackReal_* primitive to the C type of that Real. Every cast (and
// the complex checkReal aliases) is wrapped in an outer pair of parentheses so
// it cannot bind to an operator or postfix expression (e.g. `.real()`) that
// follows the macro at a call site.
#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPDoubleUtils_unpackReal(object) \
  (static_cast<double>(THPUtils_unpackReal_FLOAT(object)))
#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPFloatUtils_unpackReal(object) \
  (static_cast<float>(THPUtils_unpackReal_FLOAT(object)))
#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPHalfUtils_unpackReal(object) \
  (static_cast<at::Half>(THPUtils_unpackReal_FLOAT(object)))
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
#define THPComplexDoubleUtils_checkReal(object) \
  (THPUtils_checkReal_COMPLEX(object))
#define THPComplexDoubleUtils_unpackReal(object) \
  THPUtils_unpackReal_COMPLEX(object)
#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value)
#define THPComplexFloatUtils_checkReal(object) \
  (THPUtils_checkReal_COMPLEX(object))
#define THPComplexFloatUtils_unpackReal(object) \
  (static_cast<c10::complex<float>>(THPUtils_unpackReal_COMPLEX(object)))
#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value)
#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPBFloat16Utils_unpackReal(object) \
  (static_cast<at::BFloat16>(THPUtils_unpackReal_FLOAT(object)))
#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value)
#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value)

#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
// NOTE(review): this yields whatever THPUtils_unpackReal_BOOL yields -- the
// PyObject* (Py_True or Py_False), not a C bool.
#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value)
#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object)
// THPUtils_unpackReal_BOOL yields a PyObject*; compare it with Py_True to get
// 0/1. Casting the pointer itself to int64_t would produce its address.
#define THPBoolUtils_unpackAccreal(object) \
  (static_cast<int64_t>(THPUtils_unpackReal_BOOL(object) == Py_True))
#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value)
#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPLongUtils_unpackReal(object) \
  (static_cast<int64_t>(THPUtils_unpackReal_INT(object)))
#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPIntUtils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPShortUtils_unpackReal(object) \
  (static_cast<short>(THPUtils_unpackReal_INT(object)))
#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPCharUtils_unpackReal(object) \
  (static_cast<char>(THPUtils_unpackReal_INT(object)))
#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value)
#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPByteUtils_unpackReal(object) \
  (static_cast<unsigned char>(THPUtils_unpackReal_INT(object)))
#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value)
// quantized types: all unpacked through `int`
#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt8Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt8Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQInt32Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt4x2Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value)
#define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPQUInt2x4Utils_unpackReal(object) \
  (static_cast<int>(THPUtils_unpackReal_INT(object)))
#define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value)
142
/*
 From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c

 When this code is built as a shared library, some compilers refuse to let a
 static PyTypeObject initializer take the address of a Python object that is
 defined in a different library. DEFERRED_ADDRESS tags every slot where such
 an address would go: it always expands to nullptr, and the module init
 function that adds the PyTypeObject to the module has to fill the tagged
 slots in at runtime. The argument is documentation only and is discarded.
*/
#define DEFERRED_ADDRESS(ADDR) nullptr
153
// Early-return assertion helpers.
//   THPUtils_assertRet(value, cond, fmt, ...): when `cond` is false, calls
//     THPUtils_setError(fmt, ...) and returns `value` from the enclosing
//     function.
//   THPUtils_assert(cond, fmt, ...): the same, returning nullptr (so the
//     enclosing function's return type must accept nullptr).
// The do { } while (0) wrapper turns the expansion into a single statement
// that takes the caller's trailing semicolon. A bare `if { }` would break
// `if (x) THPUtils_assert(...); else ...` and could capture a following else.
#define THPUtils_assert(cond, ...) \
  THPUtils_assertRet(nullptr, cond, __VA_ARGS__)
#define THPUtils_assertRet(value, cond, ...) \
  do {                                       \
    if (THP_EXPECT(!(cond), 0)) {            \
      THPUtils_setError(__VA_ARGS__);        \
      return value;                          \
    }                                        \
  } while (0)
161TORCH_PYTHON_API void THPUtils_setError(const char* format, ...);
162TORCH_PYTHON_API void THPUtils_invalidArguments(
163 PyObject* given_args,
164 PyObject* given_kwargs,
165 const char* function_name,
166 size_t num_options,
167 ...);
168
169bool THPUtils_checkIntTuple(PyObject* arg);
170std::vector<int> THPUtils_unpackIntTuple(PyObject* arg);
171
172void THPUtils_addPyMethodDefs(
173 std::vector<PyMethodDef>& vector,
174 PyMethodDef* methods);
175
176int THPUtils_getCallable(PyObject* arg, PyObject** result);
177
178typedef THPPointer<THPGenerator> THPGeneratorPtr;
179typedef class THPPointer<THPStorage> THPStoragePtr;
180
181std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg);
182PyObject* THPUtils_dispatchStateless(
183 PyObject* tensor,
184 const char* name,
185 PyObject* args,
186 PyObject* kwargs);
187
// Scalar remainder, dispatched on the category of `_real`:
//   * floating-point types use std::fmod; the result has the sign of `a`.
//     std::fmod is named explicitly so that float and long double select their
//     own overloads -- an unqualified `fmod` may resolve to the C double-only
//     function and round long double operands.
//   * integral types use the built-in `%` (truncating; the result has the sign
//     of `a`). As with `%` itself, b == 0, and min() % -1 for signed types, are
//     undefined behavior; callers must rule those inputs out.
// Any other type selects the empty primary template, which has no mod().
template <typename _real, typename = void>
struct mod_traits {};

template <typename _real>
struct mod_traits<
    _real,
    typename std::enable_if<std::is_floating_point<_real>::value>::type> {
  static _real mod(_real a, _real b) {
    return std::fmod(a, b);
  }
};

template <typename _real>
struct mod_traits<
    _real,
    typename std::enable_if<std::is_integral<_real>::value>::type> {
  static _real mod(_real a, _real b) {
    return a % b;
  }
};
208
// Backward-compatibility warning switches (definitions not visible here).
// NOTE(review): presumably each getter returns the flag last passed to the
// matching setter -- confirm against the definitions.
void setBackCompatBroadcastWarn(bool warn);
bool getBackCompatBroadcastWarn();

void setBackCompatKeepdimWarn(bool warn);
bool getBackCompatKeepdimWarn();

// NOTE(review): takes a non-const `char*` (kept as-is because changing it
// would change the exported signature); the meaning of the bool result is not
// visible from this header.
bool maybeThrowBackCompatKeepdimWarn(char* func);
215
// NB: the definition lives in torch/csrc/cuda/utils.cpp rather than next to
// the other helpers declared in this header. Only declared in USE_CUDA builds.
// NOTE(review): presumably converts the Python sequence `obj` into one
// optional CUDA stream per element -- confirm against the definition.
#ifdef USE_CUDA
std::vector<c10::optional<at::cuda::CUDAStream>> THPUtils_PySequence_to_CUDAStreamList(PyObject* obj);
#endif
221
222void storage_copy(at::Storage dst, at::Storage src, bool non_blocking = false);
223void storage_fill(at::Storage self, uint8_t value);
224void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value);
225uint8_t storage_get(at::Storage self, ptrdiff_t idx);
226
227#endif
228