1 | #include <torch/csrc/Dtype.h> |
2 | |
3 | #include <structmember.h> |
4 | #include <torch/csrc/Exceptions.h> |
5 | #include <torch/csrc/utils/object_ptr.h> |
6 | #include <torch/csrc/utils/python_strings.h> |
7 | #include <torch/csrc/utils/tensor_dtypes.h> |
8 | #include <torch/csrc/utils/tensor_types.h> |
9 | #include <cstring> |
10 | |
11 | #include <torch/csrc/Exceptions.h> |
12 | |
13 | PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) { |
14 | AT_ASSERT(name.length() < DTYPE_NAME_LEN); |
15 | auto type = (PyTypeObject*)&THPDtypeType; |
16 | auto self = THPObjectPtr{type->tp_alloc(type, 0)}; |
17 | if (!self) |
18 | throw python_error(); |
19 | auto self_ = reinterpret_cast<THPDtype*>(self.get()); |
20 | self_->scalar_type = scalar_type; |
21 | std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN); |
22 | return self.release(); |
23 | } |
24 | |
25 | PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) { |
26 | if (at::isFloatingType(self->scalar_type)) { |
27 | Py_RETURN_TRUE; |
28 | } else { |
29 | Py_RETURN_FALSE; |
30 | } |
31 | } |
32 | |
33 | PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) { |
34 | if (at::isComplexType(self->scalar_type)) { |
35 | Py_RETURN_TRUE; |
36 | } else { |
37 | Py_RETURN_FALSE; |
38 | } |
39 | } |
40 | |
41 | PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) { |
42 | HANDLE_TH_ERRORS |
43 | if (at::isSignedType(self->scalar_type)) { |
44 | Py_RETURN_TRUE; |
45 | } else { |
46 | Py_RETURN_FALSE; |
47 | } |
48 | END_HANDLE_TH_ERRORS |
49 | } |
50 | |
51 | PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) { |
52 | /* |
53 | * For singletons, a string is returned. The string should be interpreted |
54 | * as the name of a global variable. |
55 | */ |
56 | auto self = (THPDtype*)_self; |
57 | return THPUtils_packString(self->name); |
58 | } |
59 | |
60 | typedef PyObject* (*getter)(PyObject*, void*); |
61 | |
62 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
63 | static struct PyGetSetDef THPDtype_properties[] = { |
64 | {"is_floating_point" , |
65 | (getter)THPDtype_is_floating_point, |
66 | nullptr, |
67 | nullptr, |
68 | nullptr}, |
69 | {"is_complex" , (getter)THPDtype_is_complex, nullptr, nullptr, nullptr}, |
70 | {"is_signed" , (getter)THPDtype_is_signed, nullptr, nullptr, nullptr}, |
71 | {nullptr}}; |
72 | |
73 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
74 | static PyMethodDef THPDtype_methods[] = { |
75 | {"__reduce__" , THPDtype_reduce, METH_NOARGS, nullptr}, |
76 | {nullptr} /* Sentinel */ |
77 | }; |
78 | |
79 | PyObject* THPDtype_repr(THPDtype* self) { |
80 | std::string name = self->name; |
81 | return THPUtils_packString("torch." + name); |
82 | } |
83 | |
// Python type object for `torch.dtype`.
//
// The initializer is positional, so the order of the entries below must match
// the field order of PyTypeObject. Only the repr, the method table and the
// getset table are populated here; every other slot is left null/zero.
//
// Related code in this file:
//  - THPDtype_New creates instances from C++ by calling tp_alloc on this type.
//  - THPDtype_init installs tp_dict (with `__module__`) and then calls
//    PyType_Ready on this object.
PyTypeObject THPDtypeType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.dtype", /* tp_name */
    sizeof(THPDtype), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved (this slot is named tp_as_async in current CPython) */
    (reprfunc)THPDtype_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDtype_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDtype_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict (set by THPDtype_init before PyType_Ready) */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    // NOTE(review): tp_alloc is null here although THPDtype_New calls it;
    // presumably PyType_Ready fills it in from the base type - confirm.
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
123 | |
124 | void THPDtype_init(PyObject* module) { |
125 | // Set a __dict__ with `__module__` = `torch`. This means |
126 | // `__module__` value will be inherited by instances |
127 | // (i.e. `torch.float32.__module__ == "torch"`). This will prevent |
128 | // Pickle from having to search all of sys.modules in order to find |
129 | // the module when pickling a dtype instance. |
130 | // |
131 | // We have to do this in C++ because extension types are not mutable |
132 | // from Python code. |
133 | // |
134 | // See https://github.com/pytorch/pytorch/issues/65077 |
135 | TORCH_INTERNAL_ASSERT(THPDtypeType.tp_dict == nullptr); |
136 | auto dict = THPObjectPtr(PyDict_New()); |
137 | if (!dict) |
138 | throw python_error(); |
139 | auto torch = THPUtils_packString("torch" ); |
140 | if (!torch) |
141 | throw python_error(); |
142 | if (PyDict_SetItemString(dict, "__module__" , torch) < 0) { |
143 | throw python_error(); |
144 | } |
145 | THPDtypeType.tp_dict = dict.release(); |
146 | |
147 | if (PyType_Ready(&THPDtypeType) < 0) { |
148 | throw python_error(); |
149 | } |
150 | Py_INCREF(&THPDtypeType); |
151 | if (PyModule_AddObject(module, "dtype" , (PyObject*)&THPDtypeType) != 0) { |
152 | throw python_error(); |
153 | } |
154 | } |
155 | |