1 | // types.UnionType -- used to represent e.g. Union[int, str], int | str |
2 | #include "Python.h" |
3 | #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK |
4 | #include "pycore_unionobject.h" |
5 | #include "structmember.h" |
6 | |
7 | |
// Forward declaration: make_union() is defined at the bottom of this file but
// is called earlier, from _Py_union_type_or() and union_getitem().
static PyObject *make_union(PyObject *);
9 | |
10 | |
// Instance layout of a types.UnionType object.
typedef struct {
    PyObject_HEAD
    // Tuple of the union's member types; exposed read-only as __args__.
    PyObject *args;
    // Cached result of _Py_make_parameters(args). make_union() leaves this
    // NULL; it is filled in the first time __parameters__ or subscription
    // needs it.
    PyObject *parameters;
} unionobject;
16 | |
17 | static void |
18 | unionobject_dealloc(PyObject *self) |
19 | { |
20 | unionobject *alias = (unionobject *)self; |
21 | |
22 | _PyObject_GC_UNTRACK(self); |
23 | |
24 | Py_XDECREF(alias->args); |
25 | Py_XDECREF(alias->parameters); |
26 | Py_TYPE(self)->tp_free(self); |
27 | } |
28 | |
29 | static int |
30 | union_traverse(PyObject *self, visitproc visit, void *arg) |
31 | { |
32 | unionobject *alias = (unionobject *)self; |
33 | Py_VISIT(alias->args); |
34 | Py_VISIT(alias->parameters); |
35 | return 0; |
36 | } |
37 | |
38 | static Py_hash_t |
39 | union_hash(PyObject *self) |
40 | { |
41 | unionobject *alias = (unionobject *)self; |
42 | PyObject *args = PyFrozenSet_New(alias->args); |
43 | if (args == NULL) { |
44 | return (Py_hash_t)-1; |
45 | } |
46 | Py_hash_t hash = PyObject_Hash(args); |
47 | Py_DECREF(args); |
48 | return hash; |
49 | } |
50 | |
51 | static int |
52 | is_generic_alias_in_args(PyObject *args) |
53 | { |
54 | Py_ssize_t nargs = PyTuple_GET_SIZE(args); |
55 | for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { |
56 | PyObject *arg = PyTuple_GET_ITEM(args, iarg); |
57 | if (_PyGenericAlias_Check(arg)) { |
58 | return 0; |
59 | } |
60 | } |
61 | return 1; |
62 | } |
63 | |
64 | static PyObject * |
65 | union_instancecheck(PyObject *self, PyObject *instance) |
66 | { |
67 | unionobject *alias = (unionobject *) self; |
68 | Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args); |
69 | if (!is_generic_alias_in_args(alias->args)) { |
70 | PyErr_SetString(PyExc_TypeError, |
71 | "isinstance() argument 2 cannot contain a parameterized generic" ); |
72 | return NULL; |
73 | } |
74 | for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { |
75 | PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg); |
76 | if (PyType_Check(arg)) { |
77 | int res = PyObject_IsInstance(instance, arg); |
78 | if (res < 0) { |
79 | return NULL; |
80 | } |
81 | if (res) { |
82 | Py_RETURN_TRUE; |
83 | } |
84 | } |
85 | } |
86 | Py_RETURN_FALSE; |
87 | } |
88 | |
89 | static PyObject * |
90 | union_subclasscheck(PyObject *self, PyObject *instance) |
91 | { |
92 | if (!PyType_Check(instance)) { |
93 | PyErr_SetString(PyExc_TypeError, "issubclass() arg 1 must be a class" ); |
94 | return NULL; |
95 | } |
96 | unionobject *alias = (unionobject *)self; |
97 | if (!is_generic_alias_in_args(alias->args)) { |
98 | PyErr_SetString(PyExc_TypeError, |
99 | "issubclass() argument 2 cannot contain a parameterized generic" ); |
100 | return NULL; |
101 | } |
102 | Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args); |
103 | for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { |
104 | PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg); |
105 | if (PyType_Check(arg)) { |
106 | int res = PyObject_IsSubclass(instance, arg); |
107 | if (res < 0) { |
108 | return NULL; |
109 | } |
110 | if (res) { |
111 | Py_RETURN_TRUE; |
112 | } |
113 | } |
114 | } |
115 | Py_RETURN_FALSE; |
116 | } |
117 | |
118 | static PyObject * |
119 | union_richcompare(PyObject *a, PyObject *b, int op) |
120 | { |
121 | if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) { |
122 | Py_RETURN_NOTIMPLEMENTED; |
123 | } |
124 | |
125 | PyObject *a_set = PySet_New(((unionobject*)a)->args); |
126 | if (a_set == NULL) { |
127 | return NULL; |
128 | } |
129 | PyObject *b_set = PySet_New(((unionobject*)b)->args); |
130 | if (b_set == NULL) { |
131 | Py_DECREF(a_set); |
132 | return NULL; |
133 | } |
134 | PyObject *result = PyObject_RichCompare(a_set, b_set, op); |
135 | Py_DECREF(b_set); |
136 | Py_DECREF(a_set); |
137 | return result; |
138 | } |
139 | |
140 | static PyObject* |
141 | flatten_args(PyObject* args) |
142 | { |
143 | Py_ssize_t arg_length = PyTuple_GET_SIZE(args); |
144 | Py_ssize_t total_args = 0; |
145 | // Get number of total args once it's flattened. |
146 | for (Py_ssize_t i = 0; i < arg_length; i++) { |
147 | PyObject *arg = PyTuple_GET_ITEM(args, i); |
148 | if (_PyUnion_Check(arg)) { |
149 | total_args += PyTuple_GET_SIZE(((unionobject*) arg)->args); |
150 | } else { |
151 | total_args++; |
152 | } |
153 | } |
154 | // Create new tuple of flattened args. |
155 | PyObject *flattened_args = PyTuple_New(total_args); |
156 | if (flattened_args == NULL) { |
157 | return NULL; |
158 | } |
159 | Py_ssize_t pos = 0; |
160 | for (Py_ssize_t i = 0; i < arg_length; i++) { |
161 | PyObject *arg = PyTuple_GET_ITEM(args, i); |
162 | if (_PyUnion_Check(arg)) { |
163 | PyObject* nested_args = ((unionobject*)arg)->args; |
164 | Py_ssize_t nested_arg_length = PyTuple_GET_SIZE(nested_args); |
165 | for (Py_ssize_t j = 0; j < nested_arg_length; j++) { |
166 | PyObject* nested_arg = PyTuple_GET_ITEM(nested_args, j); |
167 | Py_INCREF(nested_arg); |
168 | PyTuple_SET_ITEM(flattened_args, pos, nested_arg); |
169 | pos++; |
170 | } |
171 | } else { |
172 | if (arg == Py_None) { |
173 | arg = (PyObject *)&_PyNone_Type; |
174 | } |
175 | Py_INCREF(arg); |
176 | PyTuple_SET_ITEM(flattened_args, pos, arg); |
177 | pos++; |
178 | } |
179 | } |
180 | assert(pos == total_args); |
181 | return flattened_args; |
182 | } |
183 | |
184 | static PyObject* |
185 | dedup_and_flatten_args(PyObject* args) |
186 | { |
187 | args = flatten_args(args); |
188 | if (args == NULL) { |
189 | return NULL; |
190 | } |
191 | Py_ssize_t arg_length = PyTuple_GET_SIZE(args); |
192 | PyObject *new_args = PyTuple_New(arg_length); |
193 | if (new_args == NULL) { |
194 | Py_DECREF(args); |
195 | return NULL; |
196 | } |
197 | // Add unique elements to an array. |
198 | Py_ssize_t added_items = 0; |
199 | for (Py_ssize_t i = 0; i < arg_length; i++) { |
200 | int is_duplicate = 0; |
201 | PyObject* i_element = PyTuple_GET_ITEM(args, i); |
202 | for (Py_ssize_t j = 0; j < added_items; j++) { |
203 | PyObject* j_element = PyTuple_GET_ITEM(new_args, j); |
204 | int is_ga = _PyGenericAlias_Check(i_element) && |
205 | _PyGenericAlias_Check(j_element); |
206 | // RichCompare to also deduplicate GenericAlias types (slower) |
207 | is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ) |
208 | : i_element == j_element; |
209 | // Should only happen if RichCompare fails |
210 | if (is_duplicate < 0) { |
211 | Py_DECREF(args); |
212 | Py_DECREF(new_args); |
213 | return NULL; |
214 | } |
215 | if (is_duplicate) |
216 | break; |
217 | } |
218 | if (!is_duplicate) { |
219 | Py_INCREF(i_element); |
220 | PyTuple_SET_ITEM(new_args, added_items, i_element); |
221 | added_items++; |
222 | } |
223 | } |
224 | Py_DECREF(args); |
225 | _PyTuple_Resize(&new_args, added_items); |
226 | return new_args; |
227 | } |
228 | |
229 | static int |
230 | is_unionable(PyObject *obj) |
231 | { |
232 | return (obj == Py_None || |
233 | PyType_Check(obj) || |
234 | _PyGenericAlias_Check(obj) || |
235 | _PyUnion_Check(obj)); |
236 | } |
237 | |
238 | PyObject * |
239 | _Py_union_type_or(PyObject* self, PyObject* other) |
240 | { |
241 | if (!is_unionable(self) || !is_unionable(other)) { |
242 | Py_RETURN_NOTIMPLEMENTED; |
243 | } |
244 | |
245 | PyObject *tuple = PyTuple_Pack(2, self, other); |
246 | if (tuple == NULL) { |
247 | return NULL; |
248 | } |
249 | |
250 | PyObject *new_union = make_union(tuple); |
251 | Py_DECREF(tuple); |
252 | return new_union; |
253 | } |
254 | |
255 | static int |
256 | union_repr_item(_PyUnicodeWriter *writer, PyObject *p) |
257 | { |
258 | _Py_IDENTIFIER(__module__); |
259 | _Py_IDENTIFIER(__qualname__); |
260 | _Py_IDENTIFIER(__origin__); |
261 | _Py_IDENTIFIER(__args__); |
262 | PyObject *qualname = NULL; |
263 | PyObject *module = NULL; |
264 | PyObject *tmp; |
265 | PyObject *r = NULL; |
266 | int err; |
267 | |
268 | if (p == (PyObject *)&_PyNone_Type) { |
269 | return _PyUnicodeWriter_WriteASCIIString(writer, "None" , 4); |
270 | } |
271 | |
272 | if (_PyObject_LookupAttrId(p, &PyId___origin__, &tmp) < 0) { |
273 | goto exit; |
274 | } |
275 | |
276 | if (tmp) { |
277 | Py_DECREF(tmp); |
278 | if (_PyObject_LookupAttrId(p, &PyId___args__, &tmp) < 0) { |
279 | goto exit; |
280 | } |
281 | if (tmp) { |
282 | // It looks like a GenericAlias |
283 | Py_DECREF(tmp); |
284 | goto use_repr; |
285 | } |
286 | } |
287 | |
288 | if (_PyObject_LookupAttrId(p, &PyId___qualname__, &qualname) < 0) { |
289 | goto exit; |
290 | } |
291 | if (qualname == NULL) { |
292 | goto use_repr; |
293 | } |
294 | if (_PyObject_LookupAttrId(p, &PyId___module__, &module) < 0) { |
295 | goto exit; |
296 | } |
297 | if (module == NULL || module == Py_None) { |
298 | goto use_repr; |
299 | } |
300 | |
301 | // Looks like a class |
302 | if (PyUnicode_Check(module) && |
303 | _PyUnicode_EqualToASCIIString(module, "builtins" )) |
304 | { |
305 | // builtins don't need a module name |
306 | r = PyObject_Str(qualname); |
307 | goto exit; |
308 | } |
309 | else { |
310 | r = PyUnicode_FromFormat("%S.%S" , module, qualname); |
311 | goto exit; |
312 | } |
313 | |
314 | use_repr: |
315 | r = PyObject_Repr(p); |
316 | exit: |
317 | Py_XDECREF(qualname); |
318 | Py_XDECREF(module); |
319 | if (r == NULL) { |
320 | return -1; |
321 | } |
322 | err = _PyUnicodeWriter_WriteStr(writer, r); |
323 | Py_DECREF(r); |
324 | return err; |
325 | } |
326 | |
327 | static PyObject * |
328 | union_repr(PyObject *self) |
329 | { |
330 | unionobject *alias = (unionobject *)self; |
331 | Py_ssize_t len = PyTuple_GET_SIZE(alias->args); |
332 | |
333 | _PyUnicodeWriter writer; |
334 | _PyUnicodeWriter_Init(&writer); |
335 | for (Py_ssize_t i = 0; i < len; i++) { |
336 | if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | " , 3) < 0) { |
337 | goto error; |
338 | } |
339 | PyObject *p = PyTuple_GET_ITEM(alias->args, i); |
340 | if (union_repr_item(&writer, p) < 0) { |
341 | goto error; |
342 | } |
343 | } |
344 | return _PyUnicodeWriter_Finish(&writer); |
345 | error: |
346 | _PyUnicodeWriter_Dealloc(&writer); |
347 | return NULL; |
348 | } |
349 | |
// Data attributes exposed directly from the instance struct.
// __args__ is a read-only view of the `args` field (the tuple of members).
static PyMemberDef union_members[] = {
    {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
    {0}  // sentinel
};
354 | |
// Methods of types.UnionType. Both take exactly one argument (METH_O) and
// are the hooks that back isinstance() / issubclass() against a union.
static PyMethodDef union_methods[] = {
    {"__instancecheck__", union_instancecheck, METH_O},
    {"__subclasscheck__", union_subclasscheck, METH_O},
    {0}};  // sentinel
359 | |
360 | |
361 | static PyObject * |
362 | union_getitem(PyObject *self, PyObject *item) |
363 | { |
364 | unionobject *alias = (unionobject *)self; |
365 | // Populate __parameters__ if needed. |
366 | if (alias->parameters == NULL) { |
367 | alias->parameters = _Py_make_parameters(alias->args); |
368 | if (alias->parameters == NULL) { |
369 | return NULL; |
370 | } |
371 | } |
372 | |
373 | PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item); |
374 | if (newargs == NULL) { |
375 | return NULL; |
376 | } |
377 | |
378 | PyObject *res; |
379 | Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); |
380 | if (nargs == 0) { |
381 | res = make_union(newargs); |
382 | } |
383 | else { |
384 | res = PyTuple_GET_ITEM(newargs, 0); |
385 | Py_INCREF(res); |
386 | for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { |
387 | PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); |
388 | Py_SETREF(res, PyNumber_Or(res, arg)); |
389 | if (res == NULL) { |
390 | break; |
391 | } |
392 | } |
393 | } |
394 | Py_DECREF(newargs); |
395 | return res; |
396 | } |
397 | |
// Mapping protocol: only subscription (union[item]) is provided.
static PyMappingMethods union_as_mapping = {
    .mp_subscript = union_getitem,
};
401 | |
402 | static PyObject * |
403 | union_parameters(PyObject *self, void *Py_UNUSED(unused)) |
404 | { |
405 | unionobject *alias = (unionobject *)self; |
406 | if (alias->parameters == NULL) { |
407 | alias->parameters = _Py_make_parameters(alias->args); |
408 | if (alias->parameters == NULL) { |
409 | return NULL; |
410 | } |
411 | } |
412 | Py_INCREF(alias->parameters); |
413 | return alias->parameters; |
414 | } |
415 | |
// Computed attributes. __parameters__ has a getter only (the setter is NULL).
static PyGetSetDef union_properties[] = {
    {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
    {0}  // sentinel
};
420 | |
// Number protocol: only the `|` slot is filled, so a union can be extended
// with further members via `union | other`.
static PyNumberMethods union_as_number = {
    .nb_or = _Py_union_type_or, // Add __or__ function
};
424 | |
// NULL-terminated list of attribute names that union_getattro() looks up on
// the instance's type object instead of on the instance itself.
static const char* const cls_attrs[] = {
    "__module__", // Required for compatibility with typing module
    NULL,
};
429 | |
430 | static PyObject * |
431 | union_getattro(PyObject *self, PyObject *name) |
432 | { |
433 | unionobject *alias = (unionobject *)self; |
434 | if (PyUnicode_Check(name)) { |
435 | for (const char * const *p = cls_attrs; ; p++) { |
436 | if (*p == NULL) { |
437 | break; |
438 | } |
439 | if (_PyUnicode_EqualToASCIIString(name, *p)) { |
440 | return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name); |
441 | } |
442 | } |
443 | } |
444 | return PyObject_GenericGetAttr(self, name); |
445 | } |
446 | |
// Type object for types.UnionType. Instances are created by make_union();
// no tp_new is set here.
PyTypeObject _PyUnion_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    .tp_name = "types.UnionType",
    .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
              "\n"
              "E.g. for int | str"),
    .tp_basicsize = sizeof(unionobject),
    // Lifetime management: instances are GC-tracked (see HAVE_GC below).
    .tp_dealloc = unionobject_dealloc,
    .tp_alloc = PyType_GenericAlloc,
    .tp_free = PyObject_GC_Del,
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
    .tp_traverse = union_traverse,
    // Hashing and equality both ignore member order (set-based).
    .tp_hash = union_hash,
    .tp_getattro = union_getattro,
    .tp_members = union_members,
    .tp_methods = union_methods,
    .tp_richcompare = union_richcompare,
    // union[item] and union | other.
    .tp_as_mapping = &union_as_mapping,
    .tp_as_number = &union_as_number,
    .tp_repr = union_repr,
    .tp_getset = union_properties,
};
469 | |
470 | static PyObject * |
471 | make_union(PyObject *args) |
472 | { |
473 | assert(PyTuple_CheckExact(args)); |
474 | |
475 | args = dedup_and_flatten_args(args); |
476 | if (args == NULL) { |
477 | return NULL; |
478 | } |
479 | if (PyTuple_GET_SIZE(args) == 1) { |
480 | PyObject *result1 = PyTuple_GET_ITEM(args, 0); |
481 | Py_INCREF(result1); |
482 | Py_DECREF(args); |
483 | return result1; |
484 | } |
485 | |
486 | unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); |
487 | if (result == NULL) { |
488 | Py_DECREF(args); |
489 | return NULL; |
490 | } |
491 | |
492 | result->parameters = NULL; |
493 | result->args = args; |
494 | _PyObject_GC_TRACK(result); |
495 | return (PyObject*)result; |
496 | } |
497 | |