1 | #include <torch/csrc/TypeInfo.h> |
2 | |
3 | #include <torch/csrc/Exceptions.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | #include <torch/csrc/utils/python_arg_parser.h> |
7 | #include <torch/csrc/utils/python_numbers.h> |
8 | #include <torch/csrc/utils/python_strings.h> |
9 | #include <torch/csrc/utils/tensor_dtypes.h> |
10 | |
11 | #include <c10/util/Exception.h> |
12 | |
13 | #include <structmember.h> |
14 | #include <cstring> |
15 | #include <limits> |
16 | #include <sstream> |
17 | |
18 | PyObject* THPFInfo_New(const at::ScalarType& type) { |
19 | auto finfo = (PyTypeObject*)&THPFInfoType; |
20 | auto self = THPObjectPtr{finfo->tp_alloc(finfo, 0)}; |
21 | if (!self) |
22 | throw python_error(); |
23 | auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get()); |
24 | self_->type = c10::toRealValueType(type); |
25 | return self.release(); |
26 | } |
27 | |
28 | PyObject* THPIInfo_New(const at::ScalarType& type) { |
29 | auto iinfo = (PyTypeObject*)&THPIInfoType; |
30 | auto self = THPObjectPtr{iinfo->tp_alloc(iinfo, 0)}; |
31 | if (!self) |
32 | throw python_error(); |
33 | auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get()); |
34 | self_->type = type; |
35 | return self.release(); |
36 | } |
37 | |
38 | PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { |
39 | HANDLE_TH_ERRORS |
40 | static torch::PythonArgParser parser({ |
41 | "finfo(ScalarType type)" , |
42 | "finfo()" , |
43 | }); |
44 | |
45 | torch::ParsedArgs<1> parsed_args; |
46 | auto r = parser.parse(args, kwargs, parsed_args); |
47 | TORCH_CHECK(r.idx < 2, "Not a type" ); |
48 | at::ScalarType scalar_type; |
49 | if (r.idx == 1) { |
50 | scalar_type = torch::tensors::get_default_scalar_type(); |
51 | // The default tensor type can only be set to a floating point type/ |
52 | AT_ASSERT(at::isFloatingType(scalar_type)); |
53 | } else { |
54 | scalar_type = r.scalartype(0); |
55 | if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) { |
56 | return PyErr_Format( |
57 | PyExc_TypeError, |
58 | "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'" , |
59 | type->tp_name); |
60 | } |
61 | } |
62 | return THPFInfo_New(scalar_type); |
63 | END_HANDLE_TH_ERRORS |
64 | } |
65 | |
66 | PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { |
67 | HANDLE_TH_ERRORS |
68 | static torch::PythonArgParser parser({ |
69 | "iinfo(ScalarType type)" , |
70 | }); |
71 | torch::ParsedArgs<1> parsed_args; |
72 | auto r = parser.parse(args, kwargs, parsed_args); |
73 | TORCH_CHECK(r.idx == 0, "Not a type" ); |
74 | |
75 | at::ScalarType scalar_type = r.scalartype(0); |
76 | if (scalar_type == at::ScalarType::Bool) { |
77 | return PyErr_Format( |
78 | PyExc_TypeError, "torch.bool is not supported by torch.iinfo" ); |
79 | } |
80 | if (!at::isIntegralType(scalar_type, /*includeBool=*/false) && |
81 | !at::isQIntType(scalar_type)) { |
82 | return PyErr_Format( |
83 | PyExc_TypeError, |
84 | "torch.iinfo() requires an integer input type. Use torch.finfo to handle '%s'" , |
85 | type->tp_name); |
86 | } |
87 | return THPIInfo_New(scalar_type); |
88 | END_HANDLE_TH_ERRORS |
89 | } |
90 | |
91 | PyObject* THPDTypeInfo_compare(THPDTypeInfo* a, THPDTypeInfo* b, int op) { |
92 | switch (op) { |
93 | case Py_EQ: |
94 | if (a->type == b->type) { |
95 | Py_RETURN_TRUE; |
96 | } else { |
97 | Py_RETURN_FALSE; |
98 | } |
99 | case Py_NE: |
100 | if (a->type != b->type) { |
101 | Py_RETURN_TRUE; |
102 | } else { |
103 | Py_RETURN_FALSE; |
104 | } |
105 | } |
106 | return Py_INCREF(Py_NotImplemented), Py_NotImplemented; |
107 | } |
108 | |
109 | static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) { |
110 | // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) |
111 | int64_t bits = elementSize(self->type) * 8; |
112 | return THPUtils_packInt64(bits); |
113 | } |
114 | |
115 | static PyObject* THPFInfo_eps(THPFInfo* self, void*) { |
116 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
117 | at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon" , [] { |
118 | return PyFloat_FromDouble( |
119 | std::numeric_limits< |
120 | at::scalar_value_type<scalar_t>::type>::epsilon()); |
121 | }); |
122 | } |
123 | |
124 | static PyObject* THPFInfo_max(THPFInfo* self, void*) { |
125 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
126 | at::kHalf, at::ScalarType::BFloat16, self->type, "max" , [] { |
127 | return PyFloat_FromDouble( |
128 | std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max()); |
129 | }); |
130 | } |
131 | |
132 | static PyObject* THPFInfo_min(THPFInfo* self, void*) { |
133 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
134 | at::kHalf, at::ScalarType::BFloat16, self->type, "lowest" , [] { |
135 | return PyFloat_FromDouble( |
136 | std::numeric_limits< |
137 | at::scalar_value_type<scalar_t>::type>::lowest()); |
138 | }); |
139 | } |
140 | |
141 | static PyObject* THPIInfo_max(THPIInfo* self, void*) { |
142 | if (at::isIntegralType(self->type, /*includeBool=*/false)) { |
143 | return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max" , [] { |
144 | return THPUtils_packInt64(std::numeric_limits<scalar_t>::max()); |
145 | }); |
146 | } |
147 | // Quantized Type |
148 | return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max" , [] { |
149 | return THPUtils_packInt64(std::numeric_limits<underlying_t>::max()); |
150 | }); |
151 | } |
152 | |
153 | static PyObject* THPIInfo_min(THPIInfo* self, void*) { |
154 | if (at::isIntegralType(self->type, /*includeBool=*/false)) { |
155 | return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min" , [] { |
156 | return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest()); |
157 | }); |
158 | } |
159 | // Quantized Type |
160 | return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min" , [] { |
161 | return THPUtils_packInt64(std::numeric_limits<underlying_t>::lowest()); |
162 | }); |
163 | } |
164 | |
165 | static PyObject* THPIInfo_dtype(THPIInfo* self, void*) { |
166 | std::string primary_name, legacy_name; |
167 | std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type); |
168 | // NOLINTNEXTLINE(clang-diagnostic-unused-local-typedef) |
169 | return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype" , [primary_name] { |
170 | return PyUnicode_FromString((char*)primary_name.data()); |
171 | }); |
172 | } |
173 | |
174 | static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) { |
175 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
176 | at::kHalf, at::ScalarType::BFloat16, self->type, "min" , [] { |
177 | return PyFloat_FromDouble( |
178 | std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min()); |
179 | }); |
180 | } |
181 | |
182 | static PyObject* THPFInfo_tiny(THPFInfo* self, void*) { |
183 | // see gh-70909, essentially the array_api prefers smallest_normal over tiny |
184 | return THPFInfo_smallest_normal(self, nullptr); |
185 | } |
186 | |
187 | static PyObject* THPFInfo_resolution(THPFInfo* self, void*) { |
188 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
189 | at::kHalf, at::ScalarType::BFloat16, self->type, "digits10" , [] { |
190 | return PyFloat_FromDouble(std::pow( |
191 | 10, |
192 | -std::numeric_limits< |
193 | at::scalar_value_type<scalar_t>::type>::digits10)); |
194 | }); |
195 | } |
196 | |
197 | static PyObject* THPFInfo_dtype(THPFInfo* self, void*) { |
198 | std::string primary_name, legacy_name; |
199 | std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type); |
200 | // NOLINTNEXTLINE(clang-diagnostic-unused-local-typedef) |
201 | return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( |
202 | at::kHalf, at::ScalarType::BFloat16, self->type, "dtype" , [primary_name] { |
203 | return PyUnicode_FromString((char*)primary_name.data()); |
204 | }); |
205 | } |
206 | |
207 | PyObject* THPFInfo_str(THPFInfo* self) { |
208 | std::ostringstream oss; |
209 | oss << "finfo(resolution=" |
210 | << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr)); |
211 | oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr)); |
212 | oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr)); |
213 | oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr)); |
214 | oss << ", smallest_normal=" |
215 | << PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr)); |
216 | oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr)); |
217 | oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")" ; |
218 | |
219 | return THPUtils_packString(oss.str().c_str()); |
220 | } |
221 | |
222 | PyObject* THPIInfo_str(THPIInfo* self) { |
223 | std::ostringstream oss; |
224 | |
225 | oss << "iinfo(min=" << PyLong_AsDouble(THPIInfo_min(self, nullptr)); |
226 | oss << ", max=" << PyLong_AsDouble(THPIInfo_max(self, nullptr)); |
227 | oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")" ; |
228 | |
229 | return THPUtils_packString(oss.str().c_str()); |
230 | } |
231 | |
232 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays) |
233 | static struct PyGetSetDef THPFInfo_properties[] = { |
234 | {"bits" , (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr}, |
235 | {"eps" , (getter)THPFInfo_eps, nullptr, nullptr, nullptr}, |
236 | {"max" , (getter)THPFInfo_max, nullptr, nullptr, nullptr}, |
237 | {"min" , (getter)THPFInfo_min, nullptr, nullptr, nullptr}, |
238 | {"smallest_normal" , |
239 | (getter)THPFInfo_smallest_normal, |
240 | nullptr, |
241 | nullptr, |
242 | nullptr}, |
243 | {"tiny" , (getter)THPFInfo_tiny, nullptr, nullptr, nullptr}, |
244 | {"resolution" , (getter)THPFInfo_resolution, nullptr, nullptr, nullptr}, |
245 | {"dtype" , (getter)THPFInfo_dtype, nullptr, nullptr, nullptr}, |
246 | {nullptr}}; |
247 | |
248 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays) |
249 | static PyMethodDef THPFInfo_methods[] = { |
250 | {nullptr} /* Sentinel */ |
251 | }; |
252 | |
253 | PyTypeObject THPFInfoType = { |
254 | PyVarObject_HEAD_INIT(nullptr, 0) "torch.finfo" , /* tp_name */ |
255 | sizeof(THPFInfo), /* tp_basicsize */ |
256 | 0, /* tp_itemsize */ |
257 | nullptr, /* tp_dealloc */ |
258 | 0, /* tp_vectorcall_offset */ |
259 | nullptr, /* tp_getattr */ |
260 | nullptr, /* tp_setattr */ |
261 | nullptr, /* tp_reserved */ |
262 | (reprfunc)THPFInfo_str, /* tp_repr */ |
263 | nullptr, /* tp_as_number */ |
264 | nullptr, /* tp_as_sequence */ |
265 | nullptr, /* tp_as_mapping */ |
266 | nullptr, /* tp_hash */ |
267 | nullptr, /* tp_call */ |
268 | (reprfunc)THPFInfo_str, /* tp_str */ |
269 | nullptr, /* tp_getattro */ |
270 | nullptr, /* tp_setattro */ |
271 | nullptr, /* tp_as_buffer */ |
272 | Py_TPFLAGS_DEFAULT, /* tp_flags */ |
273 | nullptr, /* tp_doc */ |
274 | nullptr, /* tp_traverse */ |
275 | nullptr, /* tp_clear */ |
276 | (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */ |
277 | 0, /* tp_weaklistoffset */ |
278 | nullptr, /* tp_iter */ |
279 | nullptr, /* tp_iternext */ |
280 | THPFInfo_methods, /* tp_methods */ |
281 | nullptr, /* tp_members */ |
282 | THPFInfo_properties, /* tp_getset */ |
283 | nullptr, /* tp_base */ |
284 | nullptr, /* tp_dict */ |
285 | nullptr, /* tp_descr_get */ |
286 | nullptr, /* tp_descr_set */ |
287 | 0, /* tp_dictoffset */ |
288 | nullptr, /* tp_init */ |
289 | nullptr, /* tp_alloc */ |
290 | THPFInfo_pynew, /* tp_new */ |
291 | }; |
292 | |
293 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays) |
294 | static struct PyGetSetDef THPIInfo_properties[] = { |
295 | {"bits" , (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr}, |
296 | {"max" , (getter)THPIInfo_max, nullptr, nullptr, nullptr}, |
297 | {"min" , (getter)THPIInfo_min, nullptr, nullptr, nullptr}, |
298 | {"dtype" , (getter)THPIInfo_dtype, nullptr, nullptr, nullptr}, |
299 | {nullptr}}; |
300 | |
301 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays) |
302 | static PyMethodDef THPIInfo_methods[] = { |
303 | {nullptr} /* Sentinel */ |
304 | }; |
305 | |
306 | PyTypeObject THPIInfoType = { |
307 | PyVarObject_HEAD_INIT(nullptr, 0) "torch.iinfo" , /* tp_name */ |
308 | sizeof(THPIInfo), /* tp_basicsize */ |
309 | 0, /* tp_itemsize */ |
310 | nullptr, /* tp_dealloc */ |
311 | 0, /* tp_vectorcall_offset */ |
312 | nullptr, /* tp_getattr */ |
313 | nullptr, /* tp_setattr */ |
314 | nullptr, /* tp_reserved */ |
315 | (reprfunc)THPIInfo_str, /* tp_repr */ |
316 | nullptr, /* tp_as_number */ |
317 | nullptr, /* tp_as_sequence */ |
318 | nullptr, /* tp_as_mapping */ |
319 | nullptr, /* tp_hash */ |
320 | nullptr, /* tp_call */ |
321 | (reprfunc)THPIInfo_str, /* tp_str */ |
322 | nullptr, /* tp_getattro */ |
323 | nullptr, /* tp_setattro */ |
324 | nullptr, /* tp_as_buffer */ |
325 | Py_TPFLAGS_DEFAULT, /* tp_flags */ |
326 | nullptr, /* tp_doc */ |
327 | nullptr, /* tp_traverse */ |
328 | nullptr, /* tp_clear */ |
329 | (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */ |
330 | 0, /* tp_weaklistoffset */ |
331 | nullptr, /* tp_iter */ |
332 | nullptr, /* tp_iternext */ |
333 | THPIInfo_methods, /* tp_methods */ |
334 | nullptr, /* tp_members */ |
335 | THPIInfo_properties, /* tp_getset */ |
336 | nullptr, /* tp_base */ |
337 | nullptr, /* tp_dict */ |
338 | nullptr, /* tp_descr_get */ |
339 | nullptr, /* tp_descr_set */ |
340 | 0, /* tp_dictoffset */ |
341 | nullptr, /* tp_init */ |
342 | nullptr, /* tp_alloc */ |
343 | THPIInfo_pynew, /* tp_new */ |
344 | }; |
345 | |
346 | void THPDTypeInfo_init(PyObject* module) { |
347 | if (PyType_Ready(&THPFInfoType) < 0) { |
348 | throw python_error(); |
349 | } |
350 | Py_INCREF(&THPFInfoType); |
351 | if (PyModule_AddObject(module, "finfo" , (PyObject*)&THPFInfoType) != 0) { |
352 | throw python_error(); |
353 | } |
354 | if (PyType_Ready(&THPIInfoType) < 0) { |
355 | throw python_error(); |
356 | } |
357 | Py_INCREF(&THPIInfoType); |
358 | if (PyModule_AddObject(module, "iinfo" , (PyObject*)&THPIInfoType) != 0) { |
359 | throw python_error(); |
360 | } |
361 | } |
362 | |