1 | /* AST Optimizer */ |
2 | #include "Python.h" |
3 | #include "pycore_ast.h" // _PyAST_GetDocString() |
4 | #include "pycore_compile.h" // _PyASTOptimizeState |
5 | #include "pycore_pystate.h" // _PyThreadState_GET() |
6 | |
7 | |
8 | static int |
9 | make_const(expr_ty node, PyObject *val, PyArena *arena) |
10 | { |
11 | // Even if no new value was calculated, make_const may still |
12 | // need to clear an error (e.g. for division by zero) |
13 | if (val == NULL) { |
14 | if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) { |
15 | return 0; |
16 | } |
17 | PyErr_Clear(); |
18 | return 1; |
19 | } |
20 | if (_PyArena_AddPyObject(arena, val) < 0) { |
21 | Py_DECREF(val); |
22 | return 0; |
23 | } |
24 | node->kind = Constant_kind; |
25 | node->v.Constant.kind = NULL; |
26 | node->v.Constant.value = val; |
27 | return 1; |
28 | } |
29 | |
/* Overwrite the expression node TO with a bitwise copy of node FROM. */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(*(TO))))
31 | |
32 | static PyObject* |
33 | unary_not(PyObject *v) |
34 | { |
35 | int r = PyObject_IsTrue(v); |
36 | if (r < 0) |
37 | return NULL; |
38 | return PyBool_FromLong(!r); |
39 | } |
40 | |
41 | static int |
42 | fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) |
43 | { |
44 | expr_ty arg = node->v.UnaryOp.operand; |
45 | |
46 | if (arg->kind != Constant_kind) { |
47 | /* Fold not into comparison */ |
48 | if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind && |
49 | asdl_seq_LEN(arg->v.Compare.ops) == 1) { |
50 | /* Eq and NotEq are often implemented in terms of one another, so |
51 | folding not (self == other) into self != other breaks implementation |
52 | of !=. Detecting such cases doesn't seem worthwhile. |
53 | Python uses </> for 'is subset'/'is superset' operations on sets. |
54 | They don't satisfy not folding laws. */ |
55 | cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0); |
56 | switch (op) { |
57 | case Is: |
58 | op = IsNot; |
59 | break; |
60 | case IsNot: |
61 | op = Is; |
62 | break; |
63 | case In: |
64 | op = NotIn; |
65 | break; |
66 | case NotIn: |
67 | op = In; |
68 | break; |
69 | // The remaining comparison operators can't be safely inverted |
70 | case Eq: |
71 | case NotEq: |
72 | case Lt: |
73 | case LtE: |
74 | case Gt: |
75 | case GtE: |
76 | op = 0; // The AST enums leave "0" free as an "unused" marker |
77 | break; |
78 | // No default case, so the compiler will emit a warning if new |
79 | // comparison operators are added without being handled here |
80 | } |
81 | if (op) { |
82 | asdl_seq_SET(arg->v.Compare.ops, 0, op); |
83 | COPY_NODE(node, arg); |
84 | return 1; |
85 | } |
86 | } |
87 | return 1; |
88 | } |
89 | |
90 | typedef PyObject *(*unary_op)(PyObject*); |
91 | static const unary_op ops[] = { |
92 | [Invert] = PyNumber_Invert, |
93 | [Not] = unary_not, |
94 | [UAdd] = PyNumber_Positive, |
95 | [USub] = PyNumber_Negative, |
96 | }; |
97 | PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value); |
98 | return make_const(node, newval, arena); |
99 | } |
100 | |
/* Check that a collection does not contain too many items (counting the
   items of nested collections as well).  This prevents creating a constant
   whose hash would take too long to compute.
   Only tuples and frozensets are treated as collections; any other object
   is not descended into.
   "limit" is the maximum number of items allowed.
   Returns a negative number if the total number of items exceeds the
   limit; otherwise returns the limit minus the total number of items.
*/
108 | |
109 | static Py_ssize_t |
110 | check_complexity(PyObject *obj, Py_ssize_t limit) |
111 | { |
112 | if (PyTuple_Check(obj)) { |
113 | Py_ssize_t i; |
114 | limit -= PyTuple_GET_SIZE(obj); |
115 | for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) { |
116 | limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit); |
117 | } |
118 | return limit; |
119 | } |
120 | else if (PyFrozenSet_Check(obj)) { |
121 | Py_ssize_t i = 0; |
122 | PyObject *item; |
123 | Py_hash_t hash; |
124 | limit -= PySet_GET_SIZE(obj); |
125 | while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) { |
126 | limit = check_complexity(item, limit); |
127 | } |
128 | } |
129 | return limit; |
130 | } |
131 | |
/* Upper bounds on constants produced by folding.  When a result would
   exceed them, the safe_* helpers decline to fold and the expression is
   left to be evaluated at run time. */
enum {
    MAX_INT_SIZE = 128,         /* bits */
    MAX_COLLECTION_SIZE = 256,  /* items */
    MAX_STR_SIZE = 4096,        /* characters */
    MAX_TOTAL_ITEMS = 1024,     /* including nested collections */
};
136 | |
137 | static PyObject * |
138 | safe_multiply(PyObject *v, PyObject *w) |
139 | { |
140 | if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) { |
141 | size_t vbits = _PyLong_NumBits(v); |
142 | size_t wbits = _PyLong_NumBits(w); |
143 | if (vbits == (size_t)-1 || wbits == (size_t)-1) { |
144 | return NULL; |
145 | } |
146 | if (vbits + wbits > MAX_INT_SIZE) { |
147 | return NULL; |
148 | } |
149 | } |
150 | else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) { |
151 | Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) : |
152 | PySet_GET_SIZE(w); |
153 | if (size) { |
154 | long n = PyLong_AsLong(v); |
155 | if (n < 0 || n > MAX_COLLECTION_SIZE / size) { |
156 | return NULL; |
157 | } |
158 | if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) { |
159 | return NULL; |
160 | } |
161 | } |
162 | } |
163 | else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) { |
164 | Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) : |
165 | PyBytes_GET_SIZE(w); |
166 | if (size) { |
167 | long n = PyLong_AsLong(v); |
168 | if (n < 0 || n > MAX_STR_SIZE / size) { |
169 | return NULL; |
170 | } |
171 | } |
172 | } |
173 | else if (PyLong_Check(w) && |
174 | (PyTuple_Check(v) || PyFrozenSet_Check(v) || |
175 | PyUnicode_Check(v) || PyBytes_Check(v))) |
176 | { |
177 | return safe_multiply(w, v); |
178 | } |
179 | |
180 | return PyNumber_Multiply(v, w); |
181 | } |
182 | |
183 | static PyObject * |
184 | safe_power(PyObject *v, PyObject *w) |
185 | { |
186 | if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) { |
187 | size_t vbits = _PyLong_NumBits(v); |
188 | size_t wbits = PyLong_AsSize_t(w); |
189 | if (vbits == (size_t)-1 || wbits == (size_t)-1) { |
190 | return NULL; |
191 | } |
192 | if (vbits > MAX_INT_SIZE / wbits) { |
193 | return NULL; |
194 | } |
195 | } |
196 | |
197 | return PyNumber_Power(v, w, Py_None); |
198 | } |
199 | |
200 | static PyObject * |
201 | safe_lshift(PyObject *v, PyObject *w) |
202 | { |
203 | if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) { |
204 | size_t vbits = _PyLong_NumBits(v); |
205 | size_t wbits = PyLong_AsSize_t(w); |
206 | if (vbits == (size_t)-1 || wbits == (size_t)-1) { |
207 | return NULL; |
208 | } |
209 | if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) { |
210 | return NULL; |
211 | } |
212 | } |
213 | |
214 | return PyNumber_Lshift(v, w); |
215 | } |
216 | |
217 | static PyObject * |
218 | safe_mod(PyObject *v, PyObject *w) |
219 | { |
220 | if (PyUnicode_Check(v) || PyBytes_Check(v)) { |
221 | return NULL; |
222 | } |
223 | |
224 | return PyNumber_Remainder(v, w); |
225 | } |
226 | |
227 | static int |
228 | fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) |
229 | { |
230 | expr_ty lhs, rhs; |
231 | lhs = node->v.BinOp.left; |
232 | rhs = node->v.BinOp.right; |
233 | if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) { |
234 | return 1; |
235 | } |
236 | |
237 | PyObject *lv = lhs->v.Constant.value; |
238 | PyObject *rv = rhs->v.Constant.value; |
239 | PyObject *newval = NULL; |
240 | |
241 | switch (node->v.BinOp.op) { |
242 | case Add: |
243 | newval = PyNumber_Add(lv, rv); |
244 | break; |
245 | case Sub: |
246 | newval = PyNumber_Subtract(lv, rv); |
247 | break; |
248 | case Mult: |
249 | newval = safe_multiply(lv, rv); |
250 | break; |
251 | case Div: |
252 | newval = PyNumber_TrueDivide(lv, rv); |
253 | break; |
254 | case FloorDiv: |
255 | newval = PyNumber_FloorDivide(lv, rv); |
256 | break; |
257 | case Mod: |
258 | newval = safe_mod(lv, rv); |
259 | break; |
260 | case Pow: |
261 | newval = safe_power(lv, rv); |
262 | break; |
263 | case LShift: |
264 | newval = safe_lshift(lv, rv); |
265 | break; |
266 | case RShift: |
267 | newval = PyNumber_Rshift(lv, rv); |
268 | break; |
269 | case BitOr: |
270 | newval = PyNumber_Or(lv, rv); |
271 | break; |
272 | case BitXor: |
273 | newval = PyNumber_Xor(lv, rv); |
274 | break; |
275 | case BitAnd: |
276 | newval = PyNumber_And(lv, rv); |
277 | break; |
278 | // No builtin constants implement the following operators |
279 | case MatMult: |
280 | return 1; |
281 | // No default case, so the compiler will emit a warning if new binary |
282 | // operators are added without being handled here |
283 | } |
284 | |
285 | return make_const(node, newval, arena); |
286 | } |
287 | |
288 | static PyObject* |
289 | make_const_tuple(asdl_expr_seq *elts) |
290 | { |
291 | for (int i = 0; i < asdl_seq_LEN(elts); i++) { |
292 | expr_ty e = (expr_ty)asdl_seq_GET(elts, i); |
293 | if (e->kind != Constant_kind) { |
294 | return NULL; |
295 | } |
296 | } |
297 | |
298 | PyObject *newval = PyTuple_New(asdl_seq_LEN(elts)); |
299 | if (newval == NULL) { |
300 | return NULL; |
301 | } |
302 | |
303 | for (int i = 0; i < asdl_seq_LEN(elts); i++) { |
304 | expr_ty e = (expr_ty)asdl_seq_GET(elts, i); |
305 | PyObject *v = e->v.Constant.value; |
306 | Py_INCREF(v); |
307 | PyTuple_SET_ITEM(newval, i, v); |
308 | } |
309 | return newval; |
310 | } |
311 | |
312 | static int |
313 | fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) |
314 | { |
315 | PyObject *newval; |
316 | |
317 | if (node->v.Tuple.ctx != Load) |
318 | return 1; |
319 | |
320 | newval = make_const_tuple(node->v.Tuple.elts); |
321 | return make_const(node, newval, arena); |
322 | } |
323 | |
324 | static int |
325 | fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) |
326 | { |
327 | PyObject *newval; |
328 | expr_ty arg, idx; |
329 | |
330 | arg = node->v.Subscript.value; |
331 | idx = node->v.Subscript.slice; |
332 | if (node->v.Subscript.ctx != Load || |
333 | arg->kind != Constant_kind || |
334 | idx->kind != Constant_kind) |
335 | { |
336 | return 1; |
337 | } |
338 | |
339 | newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value); |
340 | return make_const(node, newval, arena); |
341 | } |
342 | |
343 | /* Change literal list or set of constants into constant |
344 | tuple or frozenset respectively. Change literal list of |
345 | non-constants into tuple. |
346 | Used for right operand of "in" and "not in" tests and for iterable |
347 | in "for" loop and comprehensions. |
348 | */ |
349 | static int |
350 | fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state) |
351 | { |
352 | PyObject *newval; |
353 | if (arg->kind == List_kind) { |
354 | /* First change a list into tuple. */ |
355 | asdl_expr_seq *elts = arg->v.List.elts; |
356 | Py_ssize_t n = asdl_seq_LEN(elts); |
357 | for (Py_ssize_t i = 0; i < n; i++) { |
358 | expr_ty e = (expr_ty)asdl_seq_GET(elts, i); |
359 | if (e->kind == Starred_kind) { |
360 | return 1; |
361 | } |
362 | } |
363 | expr_context_ty ctx = arg->v.List.ctx; |
364 | arg->kind = Tuple_kind; |
365 | arg->v.Tuple.elts = elts; |
366 | arg->v.Tuple.ctx = ctx; |
367 | /* Try to create a constant tuple. */ |
368 | newval = make_const_tuple(elts); |
369 | } |
370 | else if (arg->kind == Set_kind) { |
371 | newval = make_const_tuple(arg->v.Set.elts); |
372 | if (newval) { |
373 | Py_SETREF(newval, PyFrozenSet_New(newval)); |
374 | } |
375 | } |
376 | else { |
377 | return 1; |
378 | } |
379 | return make_const(arg, newval, arena); |
380 | } |
381 | |
382 | static int |
383 | fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) |
384 | { |
385 | asdl_int_seq *ops; |
386 | asdl_expr_seq *args; |
387 | Py_ssize_t i; |
388 | |
389 | ops = node->v.Compare.ops; |
390 | args = node->v.Compare.comparators; |
391 | /* TODO: optimize cases with literal arguments. */ |
392 | /* Change literal list or set in 'in' or 'not in' into |
393 | tuple or frozenset respectively. */ |
394 | i = asdl_seq_LEN(ops) - 1; |
395 | int op = asdl_seq_GET(ops, i); |
396 | if (op == In || op == NotIn) { |
397 | if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) { |
398 | return 0; |
399 | } |
400 | } |
401 | return 1; |
402 | } |
403 | |
404 | static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
405 | static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
406 | static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
407 | static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
408 | static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
409 | static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
410 | static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
411 | static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
412 | static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
413 | static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
414 | static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); |
415 | |
/* Visitor helpers for the astfold_* functions.

   They expect "ctx_" (the arena) and "state" to be in scope in the
   enclosing function, and make that function return 0 as soon as a visit
   fails.  For CALL and CALL_OPT, TYPE is unused and only documents the
   argument's type.  For the sequence macros, it selects the asdl sequence
   and element types.

   Each macro expands to exactly one statement (do { ... } while (0)), so a
   use followed by ';' is safe anywhere a statement is allowed, including
   an unbraced if/else. */

/* Visit a mandatory child. */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) \
            return 0; \
    } while (0)

/* Visit an optional child; NULL is skipped. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
            return 0; \
    } while (0)

/* Visit every non-NULL element of an asdl_<TYPE>_seq. */
#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
        for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
            if (elt != NULL && !FUNC(elt, ctx_, state)) \
                return 0; \
        } \
    } while (0)

/* Visit every element of an asdl_int_seq, converted to TYPE. */
#define CALL_INT_SEQ(FUNC, TYPE, ARG) \
    do { \
        asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
        for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
            if (!FUNC(elt, ctx_, state)) \
                return 0; \
        } \
    } while (0)
443 | |
444 | static int |
445 | astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state) |
446 | { |
447 | int docstring = _PyAST_GetDocString(stmts) != NULL; |
448 | CALL_SEQ(astfold_stmt, stmt, stmts); |
449 | if (!docstring && _PyAST_GetDocString(stmts) != NULL) { |
450 | stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0); |
451 | asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_); |
452 | if (!values) { |
453 | return 0; |
454 | } |
455 | asdl_seq_SET(values, 0, st->v.Expr.value); |
456 | expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset, |
457 | st->end_lineno, st->end_col_offset, |
458 | ctx_); |
459 | if (!expr) { |
460 | return 0; |
461 | } |
462 | st->v.Expr.value = expr; |
463 | } |
464 | return 1; |
465 | } |
466 | |
467 | static int |
468 | astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
469 | { |
470 | switch (node_->kind) { |
471 | case Module_kind: |
472 | CALL(astfold_body, asdl_seq, node_->v.Module.body); |
473 | break; |
474 | case Interactive_kind: |
475 | CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body); |
476 | break; |
477 | case Expression_kind: |
478 | CALL(astfold_expr, expr_ty, node_->v.Expression.body); |
479 | break; |
480 | // The following top level nodes don't participate in constant folding |
481 | case FunctionType_kind: |
482 | break; |
483 | // No default case, so the compiler will emit a warning if new top level |
484 | // compilation nodes are added without being handled here |
485 | } |
486 | return 1; |
487 | } |
488 | |
/* Visit one expression node.

   Children are visited before the node itself, so the folding helpers
   called afterwards (fold_binop, fold_unaryop, fold_compare, fold_subscr,
   fold_tuple) see operands that have already been folded.

   Returns 1 on success, 0 on a fatal error with an exception set.
   state->recursion_depth is incremented on entry and decremented on every
   successful return.  The failure paths (the RecursionError below, or an
   early "return 0" from a CALL* macro) leave it incremented.
   _PyAST_Optimize tolerates that because it only checks the balance after
   a successful pass. */
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation" );
        return 0;
    }
    switch (node_->kind) {
    case BoolOp_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
        break;
    case BinOp_kind:
        /* Fold both operands first, then try to fold the operation. */
        CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
        CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
        CALL(fold_binop, expr_ty, node_);
        break;
    case UnaryOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
        CALL(fold_unaryop, expr_ty, node_);
        break;
    case Lambda_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
        CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
        break;
    case IfExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
        break;
    case Dict_kind:
        /* CALL_SEQ skips NULL elements, so NULL entries in keys are fine. */
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
        break;
    case Set_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
        break;
    case ListComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
        break;
    case SetComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
        break;
    case DictComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
        CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
        break;
    case GeneratorExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
        break;
    case Await_kind:
        CALL(astfold_expr, expr_ty, node_->v.Await.value);
        break;
    case Yield_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
        break;
    case YieldFrom_kind:
        CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
        break;
    case Compare_kind:
        /* fold_compare may turn the last "in"/"not in" operand into a
           tuple/frozenset (via fold_iter). */
        CALL(astfold_expr, expr_ty, node_->v.Compare.left);
        CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
        CALL(fold_compare, expr_ty, node_);
        break;
    case Call_kind:
        CALL(astfold_expr, expr_ty, node_->v.Call.func);
        CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
        CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
        break;
    case FormattedValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
        CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
        break;
    case JoinedStr_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
        break;
    case Attribute_kind:
        CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
        break;
    case Subscript_kind:
        /* fold_subscr only folds Load-context subscripts of constants. */
        CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
        CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
        CALL(fold_subscr, expr_ty, node_);
        break;
    case Starred_kind:
        CALL(astfold_expr, expr_ty, node_->v.Starred.value);
        break;
    case Slice_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
        break;
    case List_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
        break;
    case Tuple_kind:
        /* fold_tuple only folds Load-context tuples of constants. */
        CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
        CALL(fold_tuple, expr_ty, node_);
        break;
    case Name_kind:
        /* A loaded "__debug__" becomes the constant (not state->optimize).
           This path returns directly, so it restores the depth counter
           itself. */
        if (node_->v.Name.ctx == Load &&
                _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__" )) {
            state->recursion_depth--;
            return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
        }
        break;
    case NamedExpr_kind:
        CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
        break;
    case Constant_kind:
        // Already a constant, nothing further to do
        break;
    // No default case, so the compiler will emit a warning if new expression
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
611 | |
612 | static int |
613 | astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
614 | { |
615 | CALL(astfold_expr, expr_ty, node_->value); |
616 | return 1; |
617 | } |
618 | |
619 | static int |
620 | astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
621 | { |
622 | CALL(astfold_expr, expr_ty, node_->target); |
623 | CALL(astfold_expr, expr_ty, node_->iter); |
624 | CALL_SEQ(astfold_expr, expr, node_->ifs); |
625 | |
626 | CALL(fold_iter, expr_ty, node_->iter); |
627 | return 1; |
628 | } |
629 | |
630 | static int |
631 | astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
632 | { |
633 | CALL_SEQ(astfold_arg, arg, node_->posonlyargs); |
634 | CALL_SEQ(astfold_arg, arg, node_->args); |
635 | CALL_OPT(astfold_arg, arg_ty, node_->vararg); |
636 | CALL_SEQ(astfold_arg, arg, node_->kwonlyargs); |
637 | CALL_SEQ(astfold_expr, expr, node_->kw_defaults); |
638 | CALL_OPT(astfold_arg, arg_ty, node_->kwarg); |
639 | CALL_SEQ(astfold_expr, expr, node_->defaults); |
640 | return 1; |
641 | } |
642 | |
643 | static int |
644 | astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
645 | { |
646 | if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) { |
647 | CALL_OPT(astfold_expr, expr_ty, node_->annotation); |
648 | } |
649 | return 1; |
650 | } |
651 | |
/* Visit one statement node and everything beneath it.

   Function and class bodies go through astfold_body() (docstring
   handling); other statement lists are visited element by element.
   Return annotations and AnnAssign annotations are only visited when
   CO_FUTURE_ANNOTATIONS is not set in state->ff_features.

   Returns 1 on success, 0 on a fatal error with an exception set.
   state->recursion_depth is incremented on entry and decremented only on
   a successful return. */
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation" );
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        /* Targets are visited too; fold_tuple/fold_subscr leave
           non-Load contexts alone. */
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        /* The iterable is rewritten (list -> tuple, set -> frozenset)
           only after the whole loop has been visited. */
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        /* Unlike For, the iterable is not passed to fold_iter. */
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
772 | |
773 | static int |
774 | astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
775 | { |
776 | switch (node_->kind) { |
777 | case ExceptHandler_kind: |
778 | CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type); |
779 | CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body); |
780 | break; |
781 | // No default case, so the compiler will emit a warning if new handler |
782 | // kinds are added without being handled here |
783 | } |
784 | return 1; |
785 | } |
786 | |
787 | static int |
788 | astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
789 | { |
790 | CALL(astfold_expr, expr_ty, node_->context_expr); |
791 | CALL_OPT(astfold_expr, expr_ty, node_->optional_vars); |
792 | return 1; |
793 | } |
794 | |
/* Visit one match-statement pattern, recursing into its subpatterns and
   embedded expressions.

   Like astfold_expr/astfold_stmt, it increments state->recursion_depth on
   entry and decrements it only on a successful return.  Returns 1 on
   success, 0 on a fatal error with an exception set. */
static int
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    // Currently, this is really only used to form complex/negative numeric
    // constants in MatchValue and MatchMapping nodes
    // We still recurse into all subexpressions and subpatterns anyway
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation" );
        return 0;
    }
    switch (node_->kind) {
    case MatchValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
        break;
    case MatchSingleton_kind:
        /* No subexpressions or subpatterns to visit. */
        break;
    case MatchSequence_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
        break;
    case MatchMapping_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
        break;
    case MatchClass_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
        break;
    case MatchStar_kind:
        /* No subexpressions or subpatterns to visit. */
        break;
    case MatchAs_kind:
        /* The subpattern may be NULL, so it is checked first. */
        if (node_->v.MatchAs.pattern) {
            CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
        }
        break;
    case MatchOr_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
        break;
    // No default case, so the compiler will emit a warning if new pattern
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
840 | |
841 | static int |
842 | astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) |
843 | { |
844 | CALL(astfold_pattern, expr_ty, node_->pattern); |
845 | CALL_OPT(astfold_expr, expr_ty, node_->guard); |
846 | CALL_SEQ(astfold_stmt, stmt, node_->body); |
847 | return 1; |
848 | } |
849 | |
/* The visitor helper macros are private to the astfold_* functions above. */
#undef CALL_INT_SEQ
#undef CALL_SEQ
#undef CALL_OPT
#undef CALL
854 | |
/* Factor applied to the interpreter's recursion depth and recursion limit
   before they are used as the optimizer's own counters.
   See comments in symtable.c. */
enum { COMPILER_STACK_FRAME_SCALE = 3 };
857 | |
858 | int |
859 | _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state) |
860 | { |
861 | PyThreadState *tstate; |
862 | int recursion_limit = Py_GetRecursionLimit(); |
863 | int starting_recursion_depth; |
864 | |
865 | /* Setup recursion depth check counters */ |
866 | tstate = _PyThreadState_GET(); |
867 | if (!tstate) { |
868 | return 0; |
869 | } |
870 | /* Be careful here to prevent overflow. */ |
871 | starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? |
872 | tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth; |
873 | state->recursion_depth = starting_recursion_depth; |
874 | state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ? |
875 | recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit; |
876 | |
877 | int ret = astfold_mod(mod, arena, state); |
878 | assert(ret || PyErr_Occurred()); |
879 | |
880 | /* Check that the recursion depth counting balanced correctly */ |
881 | if (ret && state->recursion_depth != starting_recursion_depth) { |
882 | PyErr_Format(PyExc_SystemError, |
883 | "AST optimizer recursion depth mismatch (before=%d, after=%d)" , |
884 | starting_recursion_depth, state->recursion_depth); |
885 | return 0; |
886 | } |
887 | |
888 | return ret; |
889 | } |
890 | |