1#define PY_SSIZE_T_CLEAN
2#include <torch/csrc/utils/python_compat.h>
3#include <opcode.h>
4#include <stdbool.h>
5
6// see https://bugs.python.org/issue35886
7#if PY_VERSION_HEX >= 0x03080000
8#define Py_BUILD_CORE
9#include <internal/pycore_pystate.h>
10
11// These headers were added in 3.11
12#if IS_PYTHON_3_11_PLUS
13#include <internal/pycore_frame.h>
14#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt
15#include <internal/pycore_opcode.h>
16#undef NEED_OPCODE_TABLES
17#endif
18
19#undef Py_BUILD_CORE
20#endif // PY_VERSION_HEX >= 0x03080000
21
22// All the eval APIs change in 3.11 so we need to decide which one to use on the fly
23// https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
24#if IS_PYTHON_3_11_PLUS
25#define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
26
27// The next two functions are taken from
28// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
29// These are not exported by the CPython binary and thus we have
30// to get our own implementation of them.
31// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
32// us to manually re-check that the function didn't change on the next major version
33#if PY_VERSION_HEX >= 0x030C0000 // 3.12
34#error "Please ensure that the functions below still match the CPython implementation for 3.12"
35#endif
36
// Local copy of the CPython helper of the same name (see the link above).
// Walks the code units of frame->f_code from the first one up to, but not
// including, frame->prev_instr. Returns 1 if one of them has de-quickened
// opcode `opcode` and accumulated argument `oparg`, and 0 otherwise.
static int
_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
{
  // This only works when opcode is a non-quickened form:
  assert(_PyOpcode_Deopt[opcode] == opcode);
  // Argument value built up across consecutive EXTENDED_ARG prefixes.
  int check_oparg = 0;
  for (_Py_CODEUNIT *instruction = _PyCode_CODE(frame->f_code);
       instruction < frame->prev_instr; instruction++)
  {
    int check_opcode = _PyOpcode_Deopt[_Py_OPCODE(*instruction)];
    check_oparg |= _Py_OPARG(*instruction);
    if (check_opcode == opcode && check_oparg == oparg) {
      return 1;
    }
    if (check_opcode == EXTENDED_ARG) {
      // Keep the bits gathered so far and make room for the next 8.
      check_oparg <<= 8;
    }
    else {
      check_oparg = 0;
    }
    // Step over the cache entries that trail this instruction; the loop's
    // own instruction++ then moves to the next real instruction.
    instruction += _PyOpcode_Caches[check_opcode];
  }
  return 0;
}
61
// Local re-implementation of CPython's fast-locals-to-locals merge that works
// directly on a _PyInterpreterFrame.
//
// Creates frame->f_locals as a new dict when it is NULL, then, for every
// entry of the frame's locals array, stores the value under its name in
// f_locals or deletes the name when the slot is NULL.
//
// Returns 0 on success and -1 when creating the dict, deleting an item (for a
// reason other than KeyError) or setting an item fails.
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
  /* Merge fast locals into f->f_locals */
  PyObject *locals;
  PyObject **fast;
  PyCodeObject *co;
  locals = frame->f_locals;
  if (locals == NULL) {
    locals = frame->f_locals = PyDict_New();
    if (locals == NULL)
      return -1;
  }
  co = frame->f_code;
  fast = _PyFrame_GetLocalsArray(frame);
  // COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
  // here:
  int lasti = _PyInterpreterFrame_LASTI(frame);
  if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
    /* Free vars have not been initialized -- Do that */
    // Note: this `co` shadows the outer one; both are frame->f_code.
    PyCodeObject *co = frame->f_code;
    PyObject *closure = frame->f_func->func_closure;
    int offset = co->co_nlocals + co->co_nplaincellvars;
    for (int i = 0; i < co->co_nfreevars; ++i) {
      PyObject *o = PyTuple_GET_ITEM(closure, i);
      Py_INCREF(o);
      frame->localsplus[offset + i] = o;
    }
    // COPY_FREE_VARS doesn't have inline CACHEs, either:
    frame->prev_instr = _PyCode_CODE(frame->f_code);
  }
  for (int i = 0; i < co->co_nlocalsplus; i++) {
    _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);

    /* If the namespace is unoptimized, then one of the
       following cases applies:
       1. It does not contain free variables, because it
          uses import * or is a top-level namespace.
       2. It is a class namespace.
       We don't want to accidentally copy free variables
       into the locals dict used by the class.
    */
    if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) {
      continue;
    }

    PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
    PyObject *value = fast[i];
    if (frame->stacktop) {
      if (kind & CO_FAST_FREE) {
        // The cell was set by COPY_FREE_VARS.
        assert(value != NULL && PyCell_Check(value));
        value = PyCell_GET(value);
      }
      else if (kind & CO_FAST_CELL) {
        // Note that no *_DEREF ops can happen before MAKE_CELL
        // executes.  So there's no need to duplicate the work
        // that MAKE_CELL would otherwise do later, if it hasn't
        // run yet.
        if (value != NULL) {
          if (PyCell_Check(value) &&
              _PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) {
            // (likely) MAKE_CELL must have executed already.
            value = PyCell_GET(value);
          }
          // (likely) Otherwise it is an arg (kind & CO_FAST_LOCAL),
          // with the initial value set when the frame was created...
          // (unlikely) ...or it was set to some initial value by
          // an earlier call to PyFrame_LocalsToFast().
        }
      }
    }
    else {
      assert(value == NULL);
    }
    if (value == NULL) {
      // Unbound slot: make sure the name is absent from the mapping. A
      // KeyError only means it was already absent, so it is cleared.
      if (PyObject_DelItem(locals, name) != 0) {
        if (PyErr_ExceptionMatches(PyExc_KeyError)) {
          PyErr_Clear();
        }
        else {
          return -1;
        }
      }
    }
    else {
      if (PyObject_SetItem(locals, name, value) != 0) {
        return -1;
      }
    }
  }
  return 0;
}
154
155// We need to be able to return the _PyInterpreterFrame to python so create
156// a python binding for it
157
// Python object that carries a pointer to a CPython _PyInterpreterFrame so
// the frame can be handed to Python code.
typedef struct THPPyInterpreterFrame {
  PyObject_HEAD
  _PyInterpreterFrame* frame; // Borrowed reference
} THPPyInterpreterFrame;

// Forward declaration: the getters below need it before the type object and
// the definition appear further down.
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
164
// Generates a getter named THPPyInterpreterFrame_<name> that reads the field
// `name` of the wrapped frame, casts it to PyObject* and returns it with its
// reference count incremented. A NULL field is returned as NULL
// (Py_XINCREF tolerates NULL).
#define DECLARE_PYOBJ_ATTR(name) \
static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
  PyObject* res = (PyObject*)self->frame->name; \
  Py_XINCREF(res); \
  return res; \
}

// One getter per frame field exposed to Python.
DECLARE_PYOBJ_ATTR(f_func)
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
DECLARE_PYOBJ_ATTR(f_locals)
DECLARE_PYOBJ_ATTR(f_code)
DECLARE_PYOBJ_ATTR(frame_obj)

#undef DECLARE_PYOBJ_ATTR
180
181static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
182 THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
183 return res;
184}
185
186// This is not a true attribute of the class but we do access it in python and it is hard to implement
187// on the python side, so do it here:
188static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
189 return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
190}
191
// Getter table installed as tp_getset of THPPyInterpreterFrameType. Each row
// is {attribute name, getter, setter, doc, closure}; every attribute is
// read-only (setter is NULL). Despite the "THPDevice" prefix, this table
// describes the interpreter-frame wrapper.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
    {"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
    {"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
    {"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
    {"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
    {"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
    {"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
    {"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
    {"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
    // Sentinel terminating the table.
    {NULL}};
203
204PyTypeObject THPPyInterpreterFrameType = {
205 PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame", /* tp_name */
206 sizeof(THPPyInterpreterFrame), /* tp_basicsize */
207 0, /* tp_itemsize */
208 NULL, /* tp_dealloc */
209 0, /* tp_vectorcall_offset */
210 NULL, /* tp_getattr */
211 NULL, /* tp_setattr */
212 NULL, /* tp_reserved */
213 NULL, /* tp_repr */
214 NULL, /* tp_as_number */
215 NULL, /* tp_as_sequence */
216 NULL, /* tp_as_mapping */
217 NULL, /* tp_hash */
218 NULL, /* tp_call */
219 NULL, /* tp_str */
220 NULL, /* tp_getattro */
221 NULL, /* tp_setattro */
222 NULL, /* tp_as_buffer */
223 Py_TPFLAGS_DEFAULT, /* tp_flags */
224 NULL, /* tp_doc */
225 NULL, /* tp_traverse */
226 NULL, /* tp_clear */
227 NULL, /* tp_richcompare */
228 0, /* tp_weaklistoffset */
229 NULL, /* tp_iter */
230 NULL, /* tp_iternext */
231 NULL, /* tp_methods */
232 NULL, /* tp_members */
233 THPDevice_properties, /* tp_getset */
234 NULL, /* tp_base */
235 NULL, /* tp_dict */
236 NULL, /* tp_descr_get */
237 NULL, /* tp_descr_set */
238 0, /* tp_dictoffset */
239 NULL, /* tp_init */
240 NULL, /* tp_alloc */
241 NULL, /* tp_new */
242};
243
244
245THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
246 PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
247 THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
248 if (!self)
249 return NULL;
250 self->frame = frame;
251 return self;
252}
253
254
255#else
// Before 3.11 the eval-frame hook receives a PyFrameObject.
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject

// Before 3.11 the CPython function is used directly under the THP_ name.
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
259#endif
260
// Hint that `x` is expected to be false. Where _WIN32 is defined the hint is
// dropped and the macro is just its argument.
#ifndef _WIN32
#define unlikely(x) __builtin_expect((x), 0)
#else
#define unlikely(x) (x)
#endif
266
// Aborts the process when `val` is NULL, after writing the source location to
// stderr and printing the pending Python error.
#define NULL_CHECK(val)                                           \
  do {                                                            \
    if (unlikely((val) == NULL)) {                                \
      fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
      PyErr_Print();                                              \
      abort();                                                    \
    }                                                             \
  } while (0)

// Aborts the process when `cond` is false, after writing the source location
// to stderr.
#define CHECK(cond)                                                       \
  do {                                                                    \
    if (unlikely(!(cond))) {                                              \
      fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
      abort();                                                            \
    }                                                                     \
  } while (0)
281
// Debug helpers. Unless TORCHDYNAMO_DEBUG is defined they expand to nothing,
// so their arguments are never evaluated. With TORCHDYNAMO_DEBUG they forward
// to CHECK / NULL_CHECK and trace to stderr with function name and line.
#ifndef TORCHDYNAMO_DEBUG

#define DEBUG_CHECK(cond)
#define DEBUG_NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...)
#define DEBUG_TRACE0(msg)

#else

#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)

#endif
299
// Flag to just run a frame normally
// Stored in a code object's extra slot in place of a CacheEntry list.
#define SKIP_CODE ((void*)0x1)

static PyObject* noargs = NULL; /* cached empty tuple */
static PyObject* dotzerokey = NULL; /* ".0" */
// Optional Python hooks; NULL while unset. Installed through
// set_guard_fail_hook / set_guard_error_hook.
static PyObject* guard_fail_hook = NULL;
static PyObject* guard_error_hook = NULL;

// Index of this module's slot in the per-code-object extra storage; assigned
// in torch_c_dynamo_eval_frame_init.
size_t extra_index = -1;

// Thread-specific-storage key under which each thread's eval-frame callback
// is stored.
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
311
312inline static PyObject* eval_frame_callback_get(void) {
313 void* result = PyThread_tss_get(&eval_frame_callback_key);
314 if (unlikely(result == NULL)) {
315 Py_RETURN_NONE;
316 } else {
317 return (PyObject*)result;
318 }
319}
320
321inline static void eval_frame_callback_set(PyObject* obj) {
322 PyThread_tss_set(&eval_frame_callback_key, obj);
323}
324
// No-op passed to _PyEval_RequestCodeExtraIndex in the module init function.
static void ignored(void* obj) {}
// Forward declarations: the version-specific shim below calls
// _custom_eval_frame_shim, which in turn calls _custom_eval_frame.
static PyObject* _custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag);
static PyObject* _custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback);
335#if PY_VERSION_HEX >= 0x03090000
336static PyObject* custom_eval_frame_shim(
337 PyThreadState* tstate,
338 THP_EVAL_API_FRAME_OBJECT* frame,
339 int throw_flag) {
340 return _custom_eval_frame_shim(tstate, frame, throw_flag);
341}
342#else
343static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) {
344 PyThreadState* tstate = PyThreadState_GET();
345 return _custom_eval_frame_shim(tstate, frame, throw_flag);
346}
347#endif
348
349inline static PyObject* eval_frame_default(
350 PyThreadState* tstate,
351 THP_EVAL_API_FRAME_OBJECT* frame,
352 int throw_flag) {
353#if PY_VERSION_HEX >= 0x03090000
354 if (tstate == NULL) {
355 tstate = PyThreadState_GET();
356 }
357 return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
358#else
359 return _PyEval_EvalFrameDefault(frame, throw_flag);
360#endif
361}
362
363inline static void enable_eval_frame_shim(PyThreadState* tstate) {
364#if PY_VERSION_HEX >= 0x03090000
365 if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
366 &custom_eval_frame_shim) {
367 _PyInterpreterState_SetEvalFrameFunc(
368 tstate->interp, &custom_eval_frame_shim);
369 }
370#else
371 if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
372 // First call
373 tstate->interp->eval_frame = &custom_eval_frame_shim;
374 }
375#endif
376}
377
378inline static void enable_eval_frame_default(PyThreadState* tstate) {
379#if PY_VERSION_HEX >= 0x03090000
380 if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) !=
381 &_PyEval_EvalFrameDefault) {
382 _PyInterpreterState_SetEvalFrameFunc(
383 tstate->interp, &_PyEval_EvalFrameDefault);
384 }
385#else
386 if (tstate->interp->eval_frame != &_PyEval_EvalFrameDefault) {
387 // First call
388 tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
389 }
390#endif
391}
392
393static inline PyObject* call_callback(
394 PyObject* callable,
395 THP_EVAL_API_FRAME_OBJECT* _frame,
396 long cache_len) {
397
398#if IS_PYTHON_3_11_PLUS
399 THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
400#else
401 PyFrameObject* frame = _frame;
402#endif
403 PyObject* args = Py_BuildValue("(Ol)", frame, cache_len);
404 if (args == NULL) {
405 return NULL;
406 }
407 PyObject* result = PyObject_CallObject(callable, args);
408 Py_DECREF(args);
409 return result;
410}
411
// One node of the singly linked list of compiled variants that is attached to
// a code object's extra slot.
typedef struct cache_entry {
  // check the guards: lambda: <locals of user function>: bool
  PyObject* check_fn;
  // modified user bytecode (protected by check_fn's guards)
  PyCodeObject* code;
  // on a cache miss, linked list of next thing to try
  struct cache_entry* next;
} CacheEntry;
420
421static CacheEntry* create_cache_entry(
422 CacheEntry* next,
423 PyObject* guarded_code) {
424 CacheEntry* e = (CacheEntry*)malloc(sizeof(CacheEntry));
425 DEBUG_NULL_CHECK(e);
426 e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
427 NULL_CHECK(e->check_fn);
428 e->code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code");
429 NULL_CHECK(e->code);
430 e->next = next;
431 return e;
432}
433
434static void destroy_cache_entry(CacheEntry* e) {
435 if (e == NULL || e == SKIP_CODE) {
436 return;
437 }
438 Py_XDECREF(e->check_fn);
439 Py_XDECREF(e->code);
440 destroy_cache_entry(e->next);
441 free(e);
442}
443
444inline static CacheEntry* get_extra(PyCodeObject* code) {
445 CacheEntry* extra = NULL;
446 _PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra);
447 return extra;
448}
449
450inline static void set_extra(PyCodeObject* code, CacheEntry* extra) {
451 // TODO(jansel): would it be faster to bypass this?
452 _PyCode_SetExtra((PyObject*)code, extra_index, extra);
453}
454
#ifdef TORCHDYNAMO_DEBUG
// Debug-only helper: UTF-8 name of the code object that `frame` runs.
inline static const char* name(THP_EVAL_API_FRAME_OBJECT* frame) {
  PyObject* co_name = frame->f_code->co_name;
  DEBUG_CHECK(PyUnicode_Check(co_name));
  return PyUnicode_AsUTF8(co_name);
}
#endif
461
462static PyObject* call_guard_fail_hook(
463 PyObject* hook,
464 CacheEntry* e,
465 PyObject* f_locals) {
466 // call debugging logic when a guard fails
467 PyObject* args = PyTuple_Pack(
468 4,
469 e->check_fn,
470 e->code,
471 f_locals,
472 (e->next == NULL ? Py_True : Py_False));
473 if (args == NULL) return NULL;
474 PyObject* result = PyObject_CallObject(hook, args);
475 Py_DECREF(args);
476 return result;
477}
478
479// Return value: borrowed reference
480// Is either Py_None or a PyCodeObject
481static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev) {
482 if (e == NULL) {
483 // NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
484 return Py_None;
485 }
486 PyObject *f_locals = frame->f_locals;
487 PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey);
488 PyObject* valid = NULL;
489 if (unlikely(dotzero != NULL)) {
490 // .0 is a special variable name used for implicit args
491 PyObject* args = PyTuple_Pack(1, dotzero);
492 if (args == NULL) return NULL;
493 valid = PyObject_Call(e->check_fn, args, f_locals);
494 Py_DECREF(args);
495 } else {
496 valid = PyObject_Call(e->check_fn, noargs, f_locals);
497 }
498 if (unlikely(valid == NULL)) {
499 if (guard_error_hook != NULL) {
500 PyObject *type, *value, *traceback;
501 PyErr_Fetch(&type, &value, &traceback);
502 PyObject* r = call_guard_fail_hook(guard_error_hook, e, f_locals);
503 if (r == NULL) {
504 return NULL;
505 }
506 Py_DECREF(r);
507 PyErr_Restore(type, value, traceback);
508 }
509 return NULL;
510 }
511 Py_DECREF(valid);
512 if (valid == Py_True) {
513 // Keep the head as the most recently used cache entry.
514 // If the hit cache entry is not the head of the linked list,
515 // move it to the head
516 if (prev != NULL) {
517 CacheEntry* extra = get_extra(frame->f_code);
518 prev->next = e->next;
519 e->next = extra;
520 set_extra(frame->f_code, e);
521 }
522 return (PyObject*)e->code;
523 }
524 if (unlikely(guard_fail_hook != NULL)) {
525 PyObject* r = call_guard_fail_hook(guard_fail_hook, e, f_locals);
526 if (r == NULL) {
527 return NULL;
528 }
529 Py_DECREF(r);
530 }
531 return lookup(e->next, frame, e);
532}
533
534static long cache_size(CacheEntry* e) {
535 if (e == NULL) {
536 return 0;
537 }
538 return 1 + cache_size(e->next);
539}
540
541inline static PyObject* eval_custom_code(
542 PyThreadState* tstate,
543 THP_EVAL_API_FRAME_OBJECT* frame,
544 PyCodeObject* code,
545 int throw_flag) {
546 Py_ssize_t ncells = 0;
547 Py_ssize_t nfrees = 0;
548 Py_ssize_t nlocals_new = code->co_nlocals;
549 Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
550
551 ncells = PyCode_GetNCellvars(code);
552 nfrees = PyCode_GetNFreevars(code);
553
554 DEBUG_NULL_CHECK(tstate);
555 DEBUG_NULL_CHECK(frame);
556 DEBUG_NULL_CHECK(code);
557 DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars));
558 DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars));
559 DEBUG_CHECK(nlocals_new >= nlocals_old);
560
561 PyFrameObject* shadow_obj = PyFrame_New(tstate, code, frame->f_globals, NULL);
562 #if IS_PYTHON_3_11_PLUS
563 THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj->f_frame;
564 #else
565 THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj;
566 #endif
567 if (shadow == NULL) {
568 return NULL;
569 }
570
571 #if IS_PYTHON_3_11_PLUS
572 PyObject** fastlocals_old = frame->localsplus;
573 PyObject** fastlocals_new = shadow->localsplus;
574 #else
575 PyObject** fastlocals_old = frame->f_localsplus;
576 PyObject** fastlocals_new = shadow->f_localsplus;
577 #endif
578
579 for (Py_ssize_t i = 0; i < nlocals_old; i++) {
580 Py_XINCREF(fastlocals_old[i]);
581 fastlocals_new[i] = fastlocals_old[i];
582 }
583
584 for (Py_ssize_t i = 0; i < ncells + nfrees; i++) {
585 Py_XINCREF(fastlocals_old[nlocals_old + i]);
586 fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
587 }
588
589 PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
590 Py_DECREF(shadow);
591 return result;
592}
593
594static PyObject* _custom_eval_frame_shim(
595 PyThreadState* tstate,
596 THP_EVAL_API_FRAME_OBJECT* frame,
597 int throw_flag) {
598 // Shims logic into one of three states. Can probably be refactored into a
599 // single func, later:
600 // - None: disables TorchDynamo
601 // - False: run-only mode (reuse existing compiles)
602 // - Python callable(): enables TorchDynamo
603 PyObject* callback = eval_frame_callback_get();
604
605 if (callback == Py_None) {
606 return eval_frame_default(tstate, frame, throw_flag);
607 }
608
609 return _custom_eval_frame(tstate, frame, throw_flag, callback);
610}
611
// Evaluates `frame` under an active callback.
//
// Steps, in order:
//  1. Frames whose code is marked SKIP_CODE, or that have no cache while in
//     run-only mode (callback is Py_False), run with the default evaluator.
//  2. The frame's fast locals are merged into f_locals so guards can see them.
//  3. In run-only mode the cache is consulted; a hit runs the cached code, a
//     miss runs the frame normally.
//  4. Otherwise the thread's callback is set to Py_None while guards and the
//     callback run. On a cache hit the cached code runs. On a miss the
//     callback is invoked: a non-None result becomes a new cache entry whose
//     code is run, and a None result marks the code SKIP_CODE and runs the
//     frame normally.
//
// Returns the evaluation result, or NULL with a Python exception set.
static PyObject* _custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback) {
  DEBUG_TRACE(
      "begin %s %s %i %i %i %i",
      name(frame),
      PyUnicode_AsUTF8(frame->f_code->co_filename),
      frame->f_lineno,
      frame->f_lasti,
      frame->f_iblock,
      frame->f_executing);
  CacheEntry* extra = get_extra(frame->f_code);
  if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
    DEBUG_TRACE("skip %s", name(frame));
    return eval_frame_default(tstate, frame, throw_flag);
  }

  // TODO(jansel): investigate directly using the "fast" representation
  // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
  // even though we should pass a PyFrameObject.
  if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
    DEBUG_TRACE("error %s", name(frame));
    return NULL;
  }

  // A callback of Py_False indicates "run only" mode, the cache is checked, but
  // we never compile.
  if (callback == Py_False) {
    DEBUG_TRACE("In run only mode %s", name(frame));
    PyObject* maybe_cached_code = lookup(extra, frame, NULL);
    if (maybe_cached_code == NULL) {
      // guard eval failed, keep propagating
      return NULL;
    } else if (maybe_cached_code == Py_None) {
      DEBUG_TRACE("cache miss %s", name(frame));
      return eval_frame_default(tstate, frame, throw_flag);
    }
    PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
    // used cached version
    DEBUG_TRACE("cache hit %s", name(frame));
    return eval_custom_code(tstate, frame, cached_code, throw_flag);
  }
  DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
  DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
  DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));

  // We don't run the current custom_eval_frame behavior for guards.
  // So we temporarily set the callback to Py_None to drive the correct behavior
  // in the shim.
  eval_frame_callback_set(Py_None);

  PyObject* maybe_cached_code = lookup(extra, frame, NULL);
  if (maybe_cached_code == NULL) {
    // Python error
    // NOTE: this path returns with the thread's callback still set to Py_None.
    return NULL;
  } else if (maybe_cached_code != Py_None) {
    PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
    // used cached version
    DEBUG_TRACE("cache hit %s", name(frame));
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_custom_code(tstate, frame, cached_code, throw_flag);
  }
  // cache miss

  // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
  // that gets re-interpreted as a PyObject (which it is NOT!)
  PyObject* result =
      call_callback(callback, frame, cache_size(extra));
  if (result == NULL) {
    // internal exception, returning here will leak the exception into user code
    // this is useful for debugging -- but we dont want it to happen outside of
    // testing
    // NOTE: this path also returns with the thread's callback still set to
    // Py_None.
    return NULL;
  } else if (result != Py_None) {
    DEBUG_TRACE("create cache %s", name(frame));
    // The new entry becomes the head of the list; it holds its own references
    // to the attributes of `result`, so `result` can be released right away.
    extra = create_cache_entry(extra, result);
    Py_DECREF(result);
    set_extra(frame->f_code, extra);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_custom_code(tstate, frame, extra->code, throw_flag);
  } else {
    DEBUG_TRACE("create skip %s", name(frame));
    Py_DECREF(result);
    // Drop any existing cache before replacing it with the skip marker.
    destroy_cache_entry(extra);
    set_extra(frame->f_code, SKIP_CODE);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_frame_default(tstate, frame, throw_flag);
  }
}
706
707static int active_dynamo_threads = 0;
708
709static PyObject* increment_working_threads(PyThreadState* tstate) {
710 active_dynamo_threads = active_dynamo_threads + 1;
711 if (active_dynamo_threads > 0) {
712 enable_eval_frame_shim(tstate);
713 }
714 Py_RETURN_NONE;
715}
716
717static PyObject* decrement_working_threads(PyThreadState* tstate) {
718 if (active_dynamo_threads > 0) {
719 active_dynamo_threads = active_dynamo_threads - 1;
720 if (active_dynamo_threads == 0) {
721 enable_eval_frame_default(tstate);
722 }
723 }
724 Py_RETURN_NONE;
725}
726
727static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) {
728 // Change the eval frame callback and return the old one
729 // - None: disables TorchDynamo
730 // - False: run-only mode (reuse existing compiles)
731 // - Python callable(): enables TorchDynamo
732 PyObject* old_callback = eval_frame_callback_get();
733
734 // owned by caller
735 Py_INCREF(old_callback);
736
737 if (old_callback != Py_None && new_callback == Py_None) {
738 decrement_working_threads(tstate);
739 } else if (old_callback == Py_None && new_callback != Py_None) {
740 increment_working_threads(tstate);
741 }
742
743 Py_INCREF(new_callback);
744 Py_DECREF(old_callback);
745
746 // Set thread local callback. This will drive behavior of our shim, if/when it
747 // is installed.
748 eval_frame_callback_set(new_callback);
749
750 return old_callback;
751}
752
753static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* args) {
754 PyObject* callback = NULL;
755 if (!PyArg_ParseTuple(args, "O:callback", &callback)) {
756 DEBUG_TRACE0("arg error");
757 return NULL;
758 }
759 if (callback != Py_None && callback != Py_False &&
760 !PyCallable_Check(callback)) {
761 DEBUG_TRACE0("arg error");
762 PyErr_SetString(PyExc_TypeError, "expected a callable");
763 return NULL;
764 }
765 DEBUG_TRACE(
766 "python enabled=%d and is run_only=%d",
767 callback != Py_None,
768 callback == Py_False);
769 return set_eval_frame(callback, PyThreadState_GET());
770}
771
772static PyObject* reset_code(PyObject* dummy, PyObject* args) {
773 PyObject* code = NULL;
774 if (!PyArg_ParseTuple(args, "O:code", &code)) {
775 DEBUG_TRACE0("arg error");
776 return NULL;
777 }
778 if (!PyCode_Check(code)) {
779 DEBUG_TRACE0("arg error");
780 PyErr_SetString(PyExc_TypeError, "expected a code object");
781 return NULL;
782 }
783
784 destroy_cache_entry(get_extra((PyCodeObject*)code));
785 set_extra((PyCodeObject*)code, NULL);
786 Py_RETURN_NONE;
787}
788
789static PyObject* unsupported(PyObject* dummy, PyObject* args) {
790 // a dummy C function used in testing
791 PyObject* obj1 = NULL;
792 PyObject* obj2 = NULL;
793 if (!PyArg_ParseTuple(args, "OO", &obj1, &obj2)) {
794 return NULL;
795 }
796 Py_INCREF(obj2);
797 return obj2;
798}
799
800static PyObject* skip_code(PyObject* dummy, PyObject* args) {
801 PyObject* obj = NULL;
802 if (!PyArg_ParseTuple(args, "O", &obj)) {
803 return NULL;
804 }
805 if (!PyCode_Check(obj)) {
806 PyErr_SetString(PyExc_TypeError, "expected a code object");
807 return NULL;
808 }
809 set_extra((PyCodeObject*)obj, SKIP_CODE);
810 Py_RETURN_NONE;
811}
812
813static PyObject* set_guard_fail_hook(PyObject* dummy, PyObject* args) {
814 PyObject* obj = NULL;
815 if (!PyArg_ParseTuple(args, "O", &obj)) {
816 return NULL;
817 }
818 Py_XDECREF(guard_fail_hook);
819 if (obj == Py_None) {
820 guard_fail_hook = NULL;
821 } else {
822 guard_fail_hook = obj;
823 Py_INCREF(guard_fail_hook);
824 }
825 Py_RETURN_NONE;
826}
827
828static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* args) {
829 PyObject* obj = NULL;
830 if (!PyArg_ParseTuple(args, "O", &obj)) {
831 return NULL;
832 }
833 Py_XDECREF(guard_error_hook);
834 if (obj == Py_None) {
835 guard_error_hook = NULL;
836 } else {
837 guard_error_hook = obj;
838 Py_INCREF(guard_error_hook);
839 }
840 Py_RETURN_NONE;
841}
842
// Functions exported by the module. All take positional arguments only
// (METH_VARARGS) and carry no docstring.
static PyMethodDef _methods[] = {
    {"set_eval_frame", set_eval_frame_py, METH_VARARGS, NULL},
    {"reset_code", reset_code, METH_VARARGS, NULL},
    {"unsupported", unsupported, METH_VARARGS, NULL},
    {"skip_code", skip_code, METH_VARARGS, NULL},
    {"set_guard_fail_hook", set_guard_fail_hook, METH_VARARGS, NULL},
    {"set_guard_error_hook", set_guard_error_hook, METH_VARARGS, NULL},
    // Sentinel terminating the table.
    {NULL, NULL, 0, NULL}};
851
852static struct PyModuleDef _module = {
853 PyModuleDef_HEAD_INIT,
854 "torch._C._dynamo.eval_frame",
855 "Module containing hooks to override eval_frame",
856 -1,
857 _methods};
858
859PyObject* torch_c_dynamo_eval_frame_init(void) {
860 extra_index = _PyEval_RequestCodeExtraIndex(ignored);
861
862 int result = PyThread_tss_create(&eval_frame_callback_key);
863 CHECK(result == 0);
864
865 Py_INCREF(Py_None);
866 eval_frame_callback_set(Py_None);
867
868 noargs = PyTuple_New(0);
869 dotzerokey = PyUnicode_InternFromString(".0");
870 PyObject* module = PyModule_Create(&_module);
871
872#if IS_PYTHON_3_11_PLUS
873 if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
874 return NULL;
875 }
876 Py_INCREF(&THPPyInterpreterFrameType);
877 if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
878 return NULL;
879 }
880#endif
881
882 return module;
883}
884