1 | #pragma once |
---|---|
2 | |
#include <c10/core/ScalarType.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/python_headers.h>

// Standard headers come after python_headers.h: Python.h must be included
// before any standard header.
#include <string>
6 | |
// Capacity, in characters, of the THPDtype::name buffer. The array is declared
// with DTYPE_NAME_LEN + 1 elements; the extra byte presumably holds the NUL
// terminator.
constexpr int DTYPE_NAME_LEN = 64;
8 | |
// C struct for a dtype Python object (presumably the instance layout used by
// THPDtypeType below — confirm against the type's tp_basicsize).
struct TORCH_API THPDtype {
  // PyObject_HEAD (the standard Python object header) must remain the first
  // member so the object can be handled through a PyObject*; it is followed
  // by the ATen scalar type this object represents.
  PyObject_HEAD at::ScalarType scalar_type;
  // Fixed-size name buffer: room for DTYPE_NAME_LEN characters plus one extra
  // byte (presumably the NUL terminator).
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char name[DTYPE_NAME_LEN + 1];
};

// Python type object for THPDtype instances; defined out of line (not in
// this header).
TORCH_API extern PyTypeObject THPDtypeType;
16 | |
17 | inline bool THPDtype_Check(PyObject* obj) { |
18 | return Py_TYPE(obj) == &THPDtypeType; |
19 | } |
20 | |
21 | inline bool THPPythonScalarType_Check(PyObject* obj) { |
22 | return obj == (PyObject*)(&PyFloat_Type) || |
23 | obj == (PyObject*)(&PyBool_Type) || obj == (PyObject*)(&PyLong_Type); |
24 | } |
25 | |
26 | PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name); |
27 | |
28 | void THPDtype_init(PyObject* module); |
29 |