1/* AST Optimizer */
2#include "Python.h"
3#include "pycore_ast.h" // _PyAST_GetDocString()
4#include "pycore_compile.h" // _PyASTOptimizeState
5#include "pycore_pystate.h" // _PyThreadState_GET()
6
7
static int
make_const(expr_ty node, PyObject *val, PyArena *arena)
{
    /* Rewrite `node` in place into a Constant node holding `val`.

       Ownership: on success the arena takes ownership of `val` (it is
       released when the arena is freed).  On arena failure `val` is
       released here.

       Returns 1 on success, and also when `val` is NULL and the pending
       error can be swallowed (folding is simply skipped).  Returns 0 on a
       hard error (KeyboardInterrupt pending, or arena registration failed). */

    // Even if no new value was calculated, make_const may still
    // need to clear an error (e.g. for division by zero)
    if (val == NULL) {
        // Let Ctrl-C abort compilation instead of being silently eaten.
        if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
            return 0;
        }
        // Any other folding error just means "leave the node unfolded".
        PyErr_Clear();
        return 1;
    }
    if (_PyArena_AddPyObject(arena, val) < 0) {
        Py_DECREF(val);
        return 0;
    }
    node->kind = Constant_kind;
    node->v.Constant.kind = NULL;
    node->v.Constant.value = val;
    return 1;
}
29
/* Overwrite expression node TO with a shallow struct copy of FROM.
   Both nodes live in the same arena, so no ownership transfer occurs. */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
31
32static PyObject*
33unary_not(PyObject *v)
34{
35 int r = PyObject_IsTrue(v);
36 if (r < 0)
37 return NULL;
38 return PyBool_FromLong(!r);
39}
40
static int
fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
{
    /* Fold a unary operator.  Two cases:
       - constant operand: evaluate the operator at compile time;
       - `not <single comparison>`: invert the comparison operator when
         that is safe (is/is not, in/not in).
       Returns 0 only on a hard error from make_const; 1 otherwise. */
    expr_ty arg = node->v.UnaryOp.operand;

    if (arg->kind != Constant_kind) {
        /* Fold not into comparison */
        if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
            asdl_seq_LEN(arg->v.Compare.ops) == 1) {
            /* Eq and NotEq are often implemented in terms of one another, so
               folding not (self == other) into self != other breaks implementation
               of !=. Detecting such cases doesn't seem worthwhile.
               Python uses </> for 'is subset'/'is superset' operations on sets.
               They don't satisfy not folding laws. */
            cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
            switch (op) {
            case Is:
                op = IsNot;
                break;
            case IsNot:
                op = Is;
                break;
            case In:
                op = NotIn;
                break;
            case NotIn:
                op = In;
                break;
            // The remaining comparison operators can't be safely inverted
            case Eq:
            case NotEq:
            case Lt:
            case LtE:
            case Gt:
            case GtE:
                op = 0; // The AST enums leave "0" free as an "unused" marker
                break;
            // No default case, so the compiler will emit a warning if new
            // comparison operators are added without being handled here
            }
            if (op) {
                // Replace the whole `not (a OP b)` node with `a INV_OP b`.
                asdl_seq_SET(arg->v.Compare.ops, 0, op);
                COPY_NODE(node, arg);
                return 1;
            }
        }
        return 1;
    }

    // Dispatch table indexed by the unaryop enum value.
    typedef PyObject *(*unary_op)(PyObject*);
    static const unary_op ops[] = {
        [Invert] = PyNumber_Invert,
        [Not] = unary_not,
        [UAdd] = PyNumber_Positive,
        [USub] = PyNumber_Negative,
    };
    // On evaluation failure make_const swallows the error (except
    // KeyboardInterrupt) and leaves the node unfolded.
    PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
    return make_const(node, newval, arena);
}
100
/* Check whether a collection doesn't contain too many items (including
   subcollections). This protects against creating a constant that needs
   too much time for calculating a hash.
   "limit" is the maximal number of items.
   Returns a negative number if the total number of items exceeds the
   limit. Otherwise returns the limit minus the total number of items.
*/
108
static Py_ssize_t
check_complexity(PyObject *obj, Py_ssize_t limit)
{
    /* Recursively subtract the item count of `obj` (tuples and frozensets
       only; other objects count as zero) from `limit`.  Recursion stops
       early once `limit` goes negative, so the return value is negative
       iff the total number of items exceeds the original limit. */
    if (PyTuple_Check(obj)) {
        Py_ssize_t i;
        limit -= PyTuple_GET_SIZE(obj);
        for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
            limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
        }
        return limit;
    }
    else if (PyFrozenSet_Check(obj)) {
        Py_ssize_t i = 0;
        PyObject *item;
        Py_hash_t hash;
        limit -= PySet_GET_SIZE(obj);
        while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
            limit = check_complexity(item, limit);
        }
    }
    return limit;
}
131
/* Size limits that keep constant folding from spending unbounded time or
   memory building large objects at compile time (enforced by the safe_*
   helpers below). */
#define MAX_INT_SIZE 128 /* bits */
#define MAX_COLLECTION_SIZE 256 /* items */
#define MAX_STR_SIZE 4096 /* characters */
#define MAX_TOTAL_ITEMS 1024 /* including nested collections */
136
static PyObject *
safe_multiply(PyObject *v, PyObject *w)
{
    /* Multiply v and w, but return NULL (refusing to fold) when the result
       would exceed the MAX_* limits: big int products, long sequence or
       string repetitions.  NULL may be returned without an exception set;
       make_const treats that the same as a swallowed error. */
    if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
        // Bound the product's bit length by the sum of the operands' bits.
        size_t vbits = _PyLong_NumBits(v);
        size_t wbits = _PyLong_NumBits(w);
        if (vbits == (size_t)-1 || wbits == (size_t)-1) {
            return NULL;
        }
        if (vbits + wbits > MAX_INT_SIZE) {
            return NULL;
        }
    }
    else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
        // int * collection repetition: bound both the result length and
        // the total (nested) item count.
        Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
                          PySet_GET_SIZE(w);
        if (size) {
            // n < 0 also covers PyLong_AsLong failure (returns -1).
            long n = PyLong_AsLong(v);
            if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
                return NULL;
            }
            if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
                return NULL;
            }
        }
    }
    else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
        // int * str/bytes repetition: bound the result length.
        Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
                          PyBytes_GET_SIZE(w);
        if (size) {
            long n = PyLong_AsLong(v);
            if (n < 0 || n > MAX_STR_SIZE / size) {
                return NULL;
            }
        }
    }
    else if (PyLong_Check(w) &&
             (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
              PyUnicode_Check(v) || PyBytes_Check(v)))
    {
        // Normalize `seq * int` to `int * seq` and re-check.
        return safe_multiply(w, v);
    }

    return PyNumber_Multiply(v, w);
}
182
static PyObject *
safe_power(PyObject *v, PyObject *w)
{
    /* Compute v ** w, refusing (NULL) when the integer result would
       exceed MAX_INT_SIZE bits.  Only guards int ** positive-int; other
       type combinations go straight to PyNumber_Power. */
    if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w) > 0) {
        size_t vbits = _PyLong_NumBits(v);
        // PyLong_AsSize_t returns (size_t)-1 with an exception on overflow;
        // that case is caught by the check below.
        size_t wbits = PyLong_AsSize_t(w);
        if (vbits == (size_t)-1 || wbits == (size_t)-1) {
            return NULL;
        }
        // Result has roughly vbits * wbits bits; division avoids overflow
        // in the comparison (wbits >= 1 here since Py_SIZE(w) > 0).
        if (vbits > MAX_INT_SIZE / wbits) {
            return NULL;
        }
    }

    return PyNumber_Power(v, w, Py_None);
}
199
static PyObject *
safe_lshift(PyObject *v, PyObject *w)
{
    /* Compute v << w, refusing (NULL) when the result would exceed
       MAX_INT_SIZE bits (result bit length is vbits + wbits). */
    if (PyLong_Check(v) && PyLong_Check(w) && Py_SIZE(v) && Py_SIZE(w)) {
        size_t vbits = _PyLong_NumBits(v);
        // (size_t)-1 signals failure from PyLong_AsSize_t (overflow).
        size_t wbits = PyLong_AsSize_t(w);
        if (vbits == (size_t)-1 || wbits == (size_t)-1) {
            return NULL;
        }
        // Two-step check avoids unsigned underflow in MAX_INT_SIZE - wbits.
        if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
            return NULL;
        }
    }

    return PyNumber_Lshift(v, w);
}
216
217static PyObject *
218safe_mod(PyObject *v, PyObject *w)
219{
220 if (PyUnicode_Check(v) || PyBytes_Check(v)) {
221 return NULL;
222 }
223
224 return PyNumber_Remainder(v, w);
225}
226
static int
fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
{
    /* Fold a binary operation whose operands are both Constant nodes,
       replacing the BinOp node with the computed constant.  The safe_*
       wrappers bound the size of folded results.  Returns 0 only on a
       hard error from make_const; 1 otherwise. */
    expr_ty lhs, rhs;
    lhs = node->v.BinOp.left;
    rhs = node->v.BinOp.right;
    if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
        return 1;
    }

    PyObject *lv = lhs->v.Constant.value;
    PyObject *rv = rhs->v.Constant.value;
    PyObject *newval = NULL;

    switch (node->v.BinOp.op) {
    case Add:
        newval = PyNumber_Add(lv, rv);
        break;
    case Sub:
        newval = PyNumber_Subtract(lv, rv);
        break;
    case Mult:
        newval = safe_multiply(lv, rv);
        break;
    case Div:
        newval = PyNumber_TrueDivide(lv, rv);
        break;
    case FloorDiv:
        newval = PyNumber_FloorDivide(lv, rv);
        break;
    case Mod:
        newval = safe_mod(lv, rv);
        break;
    case Pow:
        newval = safe_power(lv, rv);
        break;
    case LShift:
        newval = safe_lshift(lv, rv);
        break;
    case RShift:
        newval = PyNumber_Rshift(lv, rv);
        break;
    case BitOr:
        newval = PyNumber_Or(lv, rv);
        break;
    case BitXor:
        newval = PyNumber_Xor(lv, rv);
        break;
    case BitAnd:
        newval = PyNumber_And(lv, rv);
        break;
    // No builtin constants implement the following operators
    case MatMult:
        return 1;
    // No default case, so the compiler will emit a warning if new binary
    // operators are added without being handled here
    }

    // newval == NULL lets make_const swallow the error (e.g. ZeroDivision)
    // and leave the node unfolded.
    return make_const(node, newval, arena);
}
287
288static PyObject*
289make_const_tuple(asdl_expr_seq *elts)
290{
291 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
292 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
293 if (e->kind != Constant_kind) {
294 return NULL;
295 }
296 }
297
298 PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
299 if (newval == NULL) {
300 return NULL;
301 }
302
303 for (int i = 0; i < asdl_seq_LEN(elts); i++) {
304 expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
305 PyObject *v = e->v.Constant.value;
306 Py_INCREF(v);
307 PyTuple_SET_ITEM(newval, i, v);
308 }
309 return newval;
310}
311
312static int
313fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
314{
315 PyObject *newval;
316
317 if (node->v.Tuple.ctx != Load)
318 return 1;
319
320 newval = make_const_tuple(node->v.Tuple.elts);
321 return make_const(node, newval, arena);
322}
323
324static int
325fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
326{
327 PyObject *newval;
328 expr_ty arg, idx;
329
330 arg = node->v.Subscript.value;
331 idx = node->v.Subscript.slice;
332 if (node->v.Subscript.ctx != Load ||
333 arg->kind != Constant_kind ||
334 idx->kind != Constant_kind)
335 {
336 return 1;
337 }
338
339 newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
340 return make_const(node, newval, arena);
341}
342
343/* Change literal list or set of constants into constant
344 tuple or frozenset respectively. Change literal list of
345 non-constants into tuple.
346 Used for right operand of "in" and "not in" tests and for iterable
347 in "for" loop and comprehensions.
348*/
static int
fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
{
    /* See the comment above: lists become tuples (constant if possible),
       set displays of constants become frozensets.  Mutates `arg` in
       place.  Returns 0 only on a hard error; 1 otherwise. */
    PyObject *newval;
    if (arg->kind == List_kind) {
        /* First change a list into tuple. */
        asdl_expr_seq *elts = arg->v.List.elts;
        Py_ssize_t n = asdl_seq_LEN(elts);
        // A starred element would have different semantics in a tuple
        // display, so leave such lists alone.
        for (Py_ssize_t i = 0; i < n; i++) {
            expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
            if (e->kind == Starred_kind) {
                return 1;
            }
        }
        // Retarget the node as a Tuple reusing the same element sequence.
        expr_context_ty ctx = arg->v.List.ctx;
        arg->kind = Tuple_kind;
        arg->v.Tuple.elts = elts;
        arg->v.Tuple.ctx = ctx;
        /* Try to create a constant tuple. */
        newval = make_const_tuple(elts);
    }
    else if (arg->kind == Set_kind) {
        // Build a tuple of the constant elements, then convert it to a
        // frozenset (Py_SETREF releases the intermediate tuple).
        newval = make_const_tuple(arg->v.Set.elts);
        if (newval) {
            Py_SETREF(newval, PyFrozenSet_New(newval));
        }
    }
    else {
        return 1;
    }
    // newval may be NULL here (non-constant elements); make_const then
    // leaves the node as the (possibly retargeted) non-constant form.
    return make_const(arg, newval, arena);
}
381
static int
fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
{
    /* If the final comparison of a chain is `in`/`not in`, convert its
       right-hand literal list/set operand into a constant tuple/frozenset
       via fold_iter.  Returns 0 on error, 1 otherwise. */
    asdl_int_seq *ops;
    asdl_expr_seq *args;
    Py_ssize_t i;

    ops = node->v.Compare.ops;
    args = node->v.Compare.comparators;
    /* TODO: optimize cases with literal arguments. */
    /* Change literal list or set in 'in' or 'not in' into
       tuple or frozenset respectively. */
    // Only the last comparator can be rewritten: earlier ones are also
    // left operands of the following comparison in a chain.
    i = asdl_seq_LEN(ops) - 1;
    int op = asdl_seq_GET(ops, i);
    if (op == In || op == NotIn) {
        if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
            return 0;
        }
    }
    return 1;
}
403
/* Forward declarations for the mutually recursive AST visitor below.
   All visitors share the signature (node, arena, state) -> int, returning
   0 on error with an exception set and 1 on success. */
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
415
/* Visitor-invocation helpers.  Each expands to a call of FUNC on ARG with
   the enclosing function's `ctx_` and `state`, propagating failure by
   returning 0 from the *enclosing* function.  The TYPE parameter of CALL
   and CALL_OPT is unused (kept for uniformity with CALL_SEQ). */

#define CALL(FUNC, TYPE, ARG) \
    if (!FUNC((ARG), ctx_, state)) \
        return 0;

/* Like CALL, but a NULL ARG (optional AST field) is silently skipped. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
        return 0;

/* Visit every non-NULL element of an asdl_<TYPE>_seq. */
#define CALL_SEQ(FUNC, TYPE, ARG) { \
    int i; \
    asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
    for (i = 0; i < asdl_seq_LEN(seq); i++) { \
        TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
        if (elt != NULL && !FUNC(elt, ctx_, state)) \
            return 0; \
    } \
}

/* Visit every element of an asdl_int_seq, cast to TYPE. */
#define CALL_INT_SEQ(FUNC, TYPE, ARG) { \
    int i; \
    asdl_int_seq *seq = (ARG); /* avoid variable capture */ \
    for (i = 0; i < asdl_seq_LEN(seq); i++) { \
        TYPE elt = (TYPE)asdl_seq_GET(seq, i); \
        if (!FUNC(elt, ctx_, state)) \
            return 0; \
    } \
}
443
static int
astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Fold a statement body that may start with a docstring.  If folding
       *created* a docstring that wasn't there before (e.g. "a" + "b"
       collapsing to a constant first statement), wrap that value in a
       JoinedStr node so it is not mistaken for a real docstring. */
    int docstring = _PyAST_GetDocString(stmts) != NULL;
    CALL_SEQ(astfold_stmt, stmt, stmts);
    if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
        stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
        asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
        if (!values) {
            return 0;
        }
        asdl_seq_SET(values, 0, st->v.Expr.value);
        // JoinedStr inherits the original statement's source location.
        expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
                                        st->end_lineno, st->end_col_offset,
                                        ctx_);
        if (!expr) {
            return 0;
        }
        st->v.Expr.value = expr;
    }
    return 1;
}
466
static int
astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Entry visitor: dispatch on the kind of top-level compilation unit. */
    switch (node_->kind) {
    case Module_kind:
        // Module bodies may carry a docstring, so go through astfold_body.
        CALL(astfold_body, asdl_seq, node_->v.Module.body);
        break;
    case Interactive_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
        break;
    case Expression_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expression.body);
        break;
    // The following top level nodes don't participate in constant folding
    case FunctionType_kind:
        break;
    // No default case, so the compiler will emit a warning if new top level
    // compilation nodes are added without being handled here
    }
    return 1;
}
488
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Recursively fold an expression: visit all subexpressions first, then
       apply the node-specific folder (fold_binop, fold_unaryop, ...) so
       folding composes bottom-up.  Returns 0 on error with an exception
       set, 1 otherwise.  The depth counter guards against C stack
       overflow on deeply nested expressions; every exit path must undo
       the increment. */
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case BoolOp_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
        break;
    case BinOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
        CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
        CALL(fold_binop, expr_ty, node_);
        break;
    case UnaryOp_kind:
        CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
        CALL(fold_unaryop, expr_ty, node_);
        break;
    case Lambda_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
        CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
        break;
    case IfExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
        CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
        break;
    case Dict_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
        CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
        break;
    case Set_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
        break;
    case ListComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
        break;
    case SetComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
        break;
    case DictComp_kind:
        CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
        CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
        break;
    case GeneratorExp_kind:
        CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
        CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
        break;
    case Await_kind:
        CALL(astfold_expr, expr_ty, node_->v.Await.value);
        break;
    case Yield_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
        break;
    case YieldFrom_kind:
        CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
        break;
    case Compare_kind:
        CALL(astfold_expr, expr_ty, node_->v.Compare.left);
        CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
        CALL(fold_compare, expr_ty, node_);
        break;
    case Call_kind:
        CALL(astfold_expr, expr_ty, node_->v.Call.func);
        CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
        CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
        break;
    case FormattedValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
        CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
        break;
    case JoinedStr_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
        break;
    case Attribute_kind:
        CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
        break;
    case Subscript_kind:
        CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
        CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
        CALL(fold_subscr, expr_ty, node_);
        break;
    case Starred_kind:
        CALL(astfold_expr, expr_ty, node_->v.Starred.value);
        break;
    case Slice_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
        break;
    case List_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
        break;
    case Tuple_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
        CALL(fold_tuple, expr_ty, node_);
        break;
    case Name_kind:
        // Replace a load of __debug__ with its compile-time value: True
        // unless optimizations are enabled (state->optimize).
        if (node_->v.Name.ctx == Load &&
                _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
            state->recursion_depth--;  // balance the ++ before early return
            return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
        }
        break;
    case NamedExpr_kind:
        CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
        break;
    case Constant_kind:
        // Already a constant, nothing further to do
        break;
    // No default case, so the compiler will emit a warning if new expression
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
611
612static int
613astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
614{
615 CALL(astfold_expr, expr_ty, node_->value);
616 return 1;
617}
618
static int
astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Fold one `for ... in ... if ...` clause of a comprehension. */
    CALL(astfold_expr, expr_ty, node_->target);
    CALL(astfold_expr, expr_ty, node_->iter);
    CALL_SEQ(astfold_expr, expr, node_->ifs);

    // After generic folding, convert a literal list/set iterable into a
    // constant tuple/frozenset (see fold_iter).
    CALL(fold_iter, expr_ty, node_->iter);
    return 1;
}
629
static int
astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Fold all pieces of a function signature: annotations (via
       astfold_arg) and default-value expressions. */
    CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
    CALL_SEQ(astfold_arg, arg, node_->args);
    CALL_OPT(astfold_arg, arg_ty, node_->vararg);
    CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
    // kw_defaults may contain NULLs for required kw-only args; CALL_SEQ
    // skips NULL elements.
    CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
    CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
    CALL_SEQ(astfold_expr, expr, node_->defaults);
    return 1;
}
642
static int
astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Fold a parameter's annotation expression — but not when
       `from __future__ import annotations` (CO_FUTURE_ANNOTATIONS) is in
       effect, since then annotations are not evaluated at runtime. */
    if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
        CALL_OPT(astfold_expr, expr_ty, node_->annotation);
    }
    return 1;
}
651
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Recursively fold all expressions contained in a statement.  Returns
       0 on error with an exception set, 1 otherwise.  Depth counter
       mirrors astfold_expr's stack-overflow guard. */
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case FunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
        // Function bodies may carry a docstring — use astfold_body.
        CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
        // Return annotations are skipped under CO_FUTURE_ANNOTATIONS,
        // same as parameter annotations in astfold_arg.
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
        }
        break;
    case AsyncFunctionDef_kind:
        CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
        CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
        }
        break;
    case ClassDef_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
        CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
        CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
        CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
        break;
    case Return_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
        break;
    case Delete_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
        break;
    case Assign_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
        CALL(astfold_expr, expr_ty, node_->v.Assign.value);
        break;
    case AugAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
        CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
        break;
    case AnnAssign_kind:
        CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
        if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
            CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
        }
        CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
        break;
    case For_kind:
        CALL(astfold_expr, expr_ty, node_->v.For.target);
        CALL(astfold_expr, expr_ty, node_->v.For.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);

        // Convert a literal list/set loop iterable into a constant
        // tuple/frozenset (see fold_iter).
        CALL(fold_iter, expr_ty, node_->v.For.iter);
        break;
    case AsyncFor_kind:
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
        CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
        break;
    case While_kind:
        CALL(astfold_expr, expr_ty, node_->v.While.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
        break;
    case If_kind:
        CALL(astfold_expr, expr_ty, node_->v.If.test);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
        CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
        break;
    case With_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
        break;
    case AsyncWith_kind:
        CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
        CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
        break;
    case Raise_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
        break;
    case Try_kind:
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
        CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
        CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
        break;
    case Assert_kind:
        CALL(astfold_expr, expr_ty, node_->v.Assert.test);
        CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
        break;
    case Expr_kind:
        CALL(astfold_expr, expr_ty, node_->v.Expr.value);
        break;
    case Match_kind:
        CALL(astfold_expr, expr_ty, node_->v.Match.subject);
        CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
        break;
    // The following statements don't contain any subexpressions to be folded
    case Import_kind:
    case ImportFrom_kind:
    case Global_kind:
    case Nonlocal_kind:
    case Pass_kind:
    case Break_kind:
    case Continue_kind:
        break;
    // No default case, so the compiler will emit a warning if new statement
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
772
static int
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    /* Fold the exception-type expression and the handler body of an
       `except` clause. */
    switch (node_->kind) {
    case ExceptHandler_kind:
        CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
        CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
        break;
    // No default case, so the compiler will emit a warning if new handler
    // kinds are added without being handled here
    }
    return 1;
}
786
787static int
788astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
789{
790 CALL(astfold_expr, expr_ty, node_->context_expr);
791 CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
792 return 1;
793}
794
static int
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
    // Currently, this is really only used to form complex/negative numeric
    // constants in MatchValue and MatchMapping nodes
    // We still recurse into all subexpressions and subpatterns anyway
    if (++state->recursion_depth > state->recursion_limit) {
        PyErr_SetString(PyExc_RecursionError,
                        "maximum recursion depth exceeded during compilation");
        return 0;
    }
    switch (node_->kind) {
    case MatchValue_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
        break;
    case MatchSingleton_kind:
        // Singleton patterns (None/True/False) hold no foldable expression.
        break;
    case MatchSequence_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
        break;
    case MatchMapping_kind:
        CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
        break;
    case MatchClass_kind:
        CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
        break;
    case MatchStar_kind:
        break;
    case MatchAs_kind:
        // The sub-pattern is optional (bare capture / wildcard).
        if (node_->v.MatchAs.pattern) {
            CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
        }
        break;
    case MatchOr_kind:
        CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
        break;
    // No default case, so the compiler will emit a warning if new pattern
    // kinds are added without being handled here
    }
    state->recursion_depth--;
    return 1;
}
840
841static int
842astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
843{
844 CALL(astfold_pattern, expr_ty, node_->pattern);
845 CALL_OPT(astfold_expr, expr_ty, node_->guard);
846 CALL_SEQ(astfold_stmt, stmt, node_->body);
847 return 1;
848}
849
/* The visitor helpers are private to this file; undefine them here. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ
#undef CALL_INT_SEQ

/* See comments in symtable.c. */
// NOTE(review): presumably scales the interpreter recursion limit to the
// optimizer's (smaller) C stack frames — the rationale lives in symtable.c.
#define COMPILER_STACK_FRAME_SCALE 3
857
int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
    /* Public entry point: constant-fold the AST `mod` in place.  New
       constants are allocated in `arena`.  Returns 1 on success, 0 on
       error with an exception set. */
    PyThreadState *tstate;
    int recursion_limit = Py_GetRecursionLimit();
    int starting_recursion_depth;

    /* Setup recursion depth check counters */
    tstate = _PyThreadState_GET();
    if (!tstate) {
        return 0;
    }
    /* Be careful here to prevent overflow. */
    // Scale both the current depth and the limit by the same factor so the
    // effective headroom matches the interpreter's recursion budget.
    starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
    state->recursion_depth = starting_recursion_depth;
    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;

    int ret = astfold_mod(mod, arena, state);
    assert(ret || PyErr_Occurred());

    /* Check that the recursion depth counting balanced correctly */
    // A mismatch means some visitor path forgot to decrement on exit.
    if (ret && state->recursion_depth != starting_recursion_depth) {
        PyErr_Format(PyExc_SystemError,
            "AST optimizer recursion depth mismatch (before=%d, after=%d)",
            starting_recursion_depth, state->recursion_depth);
        return 0;
    }

    return ret;
}
890