1#include <torch/csrc/Dtype.h>
2
3#include <structmember.h>
4#include <torch/csrc/Exceptions.h>
5#include <torch/csrc/utils/object_ptr.h>
6#include <torch/csrc/utils/python_strings.h>
7#include <torch/csrc/utils/tensor_dtypes.h>
8#include <torch/csrc/utils/tensor_types.h>
9#include <cstring>
10
11#include <torch/csrc/Exceptions.h>
12
13PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) {
14 AT_ASSERT(name.length() < DTYPE_NAME_LEN);
15 auto type = (PyTypeObject*)&THPDtypeType;
16 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
17 if (!self)
18 throw python_error();
19 auto self_ = reinterpret_cast<THPDtype*>(self.get());
20 self_->scalar_type = scalar_type;
21 std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN);
22 return self.release();
23}
24
25PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) {
26 if (at::isFloatingType(self->scalar_type)) {
27 Py_RETURN_TRUE;
28 } else {
29 Py_RETURN_FALSE;
30 }
31}
32
33PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) {
34 if (at::isComplexType(self->scalar_type)) {
35 Py_RETURN_TRUE;
36 } else {
37 Py_RETURN_FALSE;
38 }
39}
40
41PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) {
42 HANDLE_TH_ERRORS
43 if (at::isSignedType(self->scalar_type)) {
44 Py_RETURN_TRUE;
45 } else {
46 Py_RETURN_FALSE;
47 }
48 END_HANDLE_TH_ERRORS
49}
50
51PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
52 /*
53 * For singletons, a string is returned. The string should be interpreted
54 * as the name of a global variable.
55 */
56 auto self = (THPDtype*)_self;
57 return THPUtils_packString(self->name);
58}
59
60typedef PyObject* (*getter)(PyObject*, void*);
61
62// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
63static struct PyGetSetDef THPDtype_properties[] = {
64 {"is_floating_point",
65 (getter)THPDtype_is_floating_point,
66 nullptr,
67 nullptr,
68 nullptr},
69 {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
70 {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
71 {nullptr}};
72
73// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
74static PyMethodDef THPDtype_methods[] = {
75 {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr},
76 {nullptr} /* Sentinel */
77};
78
79PyObject* THPDtype_repr(THPDtype* self) {
80 std::string name = self->name;
81 return THPUtils_packString("torch." + name);
82}
83
// Python type object for `torch.dtype`.
//
// This is a positional aggregate initializer, so the entries must stay in
// exactly the order of CPython's PyTypeObject fields (the trailing comments
// name each slot).
//
// Notes on the slots that are set or deliberately left null:
//  - tp_new is null, so Python code cannot instantiate this type directly.
//    Instances are created in C++ by THPDtype_New via tp_alloc.
//  - tp_dict is null here; THPDtype_init fills it with `__module__ = "torch"`
//    before calling PyType_Ready.
//  - Slots left null (e.g. tp_dealloc, tp_alloc) are presumably inherited
//    from the base type by PyType_Ready — NOTE(review): confirm nothing
//    dtype-specific needs explicit teardown.
PyTypeObject THPDtypeType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.dtype", /* tp_name */
    sizeof(THPDtype), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDtype_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDtype_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDtype_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
123
124void THPDtype_init(PyObject* module) {
125 // Set a __dict__ with `__module__` = `torch`. This means
126 // `__module__` value will be inherited by instances
127 // (i.e. `torch.float32.__module__ == "torch"`). This will prevent
128 // Pickle from having to search all of sys.modules in order to find
129 // the module when pickling a dtype instance.
130 //
131 // We have to do this in C++ because extension types are not mutable
132 // from Python code.
133 //
134 // See https://github.com/pytorch/pytorch/issues/65077
135 TORCH_INTERNAL_ASSERT(THPDtypeType.tp_dict == nullptr);
136 auto dict = THPObjectPtr(PyDict_New());
137 if (!dict)
138 throw python_error();
139 auto torch = THPUtils_packString("torch");
140 if (!torch)
141 throw python_error();
142 if (PyDict_SetItemString(dict, "__module__", torch) < 0) {
143 throw python_error();
144 }
145 THPDtypeType.tp_dict = dict.release();
146
147 if (PyType_Ready(&THPDtypeType) < 0) {
148 throw python_error();
149 }
150 Py_INCREF(&THPDtypeType);
151 if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) {
152 throw python_error();
153 }
154}
155