1 | #define PY_SSIZE_T_CLEAN |
2 | #include <torch/csrc/utils/python_compat.h> |
3 | #include <opcode.h> |
4 | #include <stdbool.h> |
5 | |
6 | // see https://bugs.python.org/issue35886 |
7 | #if PY_VERSION_HEX >= 0x03080000 |
8 | #define Py_BUILD_CORE |
9 | #include <internal/pycore_pystate.h> |
10 | |
11 | // These headers were added in 3.11 |
12 | #if IS_PYTHON_3_11_PLUS |
13 | #include <internal/pycore_frame.h> |
14 | #define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt |
15 | #include <internal/pycore_opcode.h> |
16 | #undef NEED_OPCODE_TABLES |
17 | #endif |
18 | |
19 | #undef Py_BUILD_CORE |
20 | #endif // PY_VERSION_HEX >= 0x03080000 |
21 | |
22 | // All the eval APIs change in 3.11 so we need to decide which one to use on the fly |
23 | // https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction |
24 | #if IS_PYTHON_3_11_PLUS |
25 | #define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame |
26 | |
27 | // The next two functions are taken from |
28 | // https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182 |
29 | // These are not exported by the CPython binary and thus we have |
30 | // to get our own implementation of them. |
// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
// us to manually re-check that these functions didn't change in the next minor release (3.12).
33 | #if PY_VERSION_HEX >= 0x030C0000 // 3.12 |
34 | #error "Please ensure that the functions below still match the CPython implementation for 3.12" |
35 | #endif |
36 | |
// Port of CPython's (static, unexported) _PyFrame_OpAlreadyRan.
// Returns 1 if an instruction whose de-quickened opcode is `opcode` and whose
// full argument (accumulated across EXTENDED_ARG prefixes) is `oparg` occurs
// before frame->prev_instr in the frame's bytecode, and 0 otherwise.
// `opcode` must already be a base (non-quickened) opcode.
// Keep the body in sync with CPython rather than restyling it.
static int
_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
{
    // This only works when opcode is a non-quickened form:
    assert(_PyOpcode_Deopt[opcode] == opcode);
    int check_oparg = 0;
    for (_Py_CODEUNIT *instruction = _PyCode_CODE(frame->f_code);
         instruction < frame->prev_instr; instruction++)
    {
        // Map quickened/specialized opcodes back to their base form.
        int check_opcode = _PyOpcode_Deopt[_Py_OPCODE(*instruction)];
        // Merge this unit's argument into any EXTENDED_ARG prefix seen so far.
        check_oparg |= _Py_OPARG(*instruction);
        if (check_opcode == opcode && check_oparg == oparg) {
            return 1;
        }
        if (check_opcode == EXTENDED_ARG) {
            check_oparg <<= 8;
        }
        else {
            check_oparg = 0;
        }
        // Skip the inline cache entries that follow this instruction.
        instruction += _PyOpcode_Caches[check_opcode];
    }
    return 0;
}
61 | |
// Port of CPython 3.11's _PyFrame_FastToLocalsWithError, which the CPython
// binary does not export.
// Copies the frame's "fast" slots (plain locals, cell and free variables) into
// the frame->f_locals mapping, creating that dict first if it is NULL.
// A name whose slot value is NULL is deleted from the mapping (a KeyError from
// that deletion is ignored).
// If the frame has not started (lasti < 0) and its first instruction is
// COPY_FREE_VARS, the free variables are copied in from the function's
// closure here and prev_instr is pointed at that first instruction.
// Returns 0 on success, -1 on failure (dict creation, SetItem or a non-KeyError
// DelItem failure) with the Python error indicator set.
// Keep the body in sync with CPython rather than restyling it.
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
    /* Merge fast locals into f->f_locals */
    PyObject *locals;
    PyObject **fast;
    PyCodeObject *co;
    locals = frame->f_locals;
    if (locals == NULL) {
        locals = frame->f_locals = PyDict_New();
        if (locals == NULL)
            return -1;
    }
    co = frame->f_code;
    fast = _PyFrame_GetLocalsArray(frame);
    // COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
    // here:
    int lasti = _PyInterpreterFrame_LASTI(frame);
    if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
        /* Free vars have not been initialized -- Do that */
        PyCodeObject *co = frame->f_code;
        PyObject *closure = frame->f_func->func_closure;
        // Free variables live after the plain locals and plain cell variables.
        int offset = co->co_nlocals + co->co_nplaincellvars;
        for (int i = 0; i < co->co_nfreevars; ++i) {
            PyObject *o = PyTuple_GET_ITEM(closure, i);
            Py_INCREF(o);
            frame->localsplus[offset + i] = o;
        }
        // COPY_FREE_VARS doesn't have inline CACHEs, either:
        frame->prev_instr = _PyCode_CODE(frame->f_code);
    }
    for (int i = 0; i < co->co_nlocalsplus; i++) {
        _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);

        /* If the namespace is unoptimized, then one of the
           following cases applies:
           1. It does not contain free variables, because it
              uses import * or is a top-level namespace.
           2. It is a class namespace.
           We don't want to accidentally copy free variables
           into the locals dict used by the class.
        */
        if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) {
            continue;
        }

        PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
        PyObject *value = fast[i];
        if (frame->stacktop) {
            if (kind & CO_FAST_FREE) {
                // The cell was set by COPY_FREE_VARS.
                assert(value != NULL && PyCell_Check(value));
                value = PyCell_GET(value);
            }
            else if (kind & CO_FAST_CELL) {
                // Note that no *_DEREF ops can happen before MAKE_CELL
                // executes. So there's no need to duplicate the work
                // that MAKE_CELL would otherwise do later, if it hasn't
                // run yet.
                if (value != NULL) {
                    if (PyCell_Check(value) &&
                            _PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) {
                        // (likely) MAKE_CELL must have executed already.
                        value = PyCell_GET(value);
                    }
                    // (likely) Otherwise it is an arg (kind & CO_FAST_LOCAL),
                    // with the initial value set when the frame was created...
                    // (unlikely) ...or it was set to some initial value by
                    // an earlier call to PyFrame_LocalsToFast().
                }
            }
        }
        else {
            // With an empty value stack every fast slot is expected unset.
            assert(value == NULL);
        }
        if (value == NULL) {
            // Unbound name: make sure it is absent from the mapping.
            if (PyObject_DelItem(locals, name) != 0) {
                if (PyErr_ExceptionMatches(PyExc_KeyError)) {
                    PyErr_Clear();
                }
                else {
                    return -1;
                }
            }
        }
        else {
            if (PyObject_SetItem(locals, name, value) != 0) {
                return -1;
            }
        }
    }
    return 0;
}
154 | |
155 | // We need to be able to return the _PyInterpreterFrame to python so create |
156 | // a python binding for it |
157 | |
158 | typedef struct THPPyInterpreterFrame { |
159 | PyObject_HEAD |
160 | _PyInterpreterFrame* frame; // Borrowed reference |
161 | } THPPyInterpreterFrame; |
162 | |
163 | THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame); |
164 | |
165 | #define DECLARE_PYOBJ_ATTR(name) \ |
166 | static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \ |
167 | PyObject* res = (PyObject*)self->frame->name; \ |
168 | Py_XINCREF(res); \ |
169 | return res; \ |
170 | } |
171 | |
172 | DECLARE_PYOBJ_ATTR(f_func) |
173 | DECLARE_PYOBJ_ATTR(f_globals) |
174 | DECLARE_PYOBJ_ATTR(f_builtins) |
175 | DECLARE_PYOBJ_ATTR(f_locals) |
176 | DECLARE_PYOBJ_ATTR(f_code) |
177 | DECLARE_PYOBJ_ATTR(frame_obj) |
178 | |
179 | #undef DECLARE_PYOBJ_ATTR |
180 | |
181 | static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) { |
182 | THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous); |
183 | return res; |
184 | } |
185 | |
186 | // This is not a true attribute of the class but we do access it in python and it is hard to implement |
187 | // on the python side, so do it here: |
188 | static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) { |
189 | return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame)); |
190 | } |
191 | |
192 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
193 | static struct PyGetSetDef THPDevice_properties[] = { |
194 | {"f_func" , (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL}, |
195 | {"f_globals" , (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL}, |
196 | {"f_builtins" , (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL}, |
197 | {"f_locals" , (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL}, |
198 | {"f_code" , (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL}, |
199 | {"frame_obj" , (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL}, |
200 | {"previous" , (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL}, |
201 | {"f_lasti" , (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL}, |
202 | {NULL}}; |
203 | |
204 | PyTypeObject THPPyInterpreterFrameType = { |
205 | PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame" , /* tp_name */ |
206 | sizeof(THPPyInterpreterFrame), /* tp_basicsize */ |
207 | 0, /* tp_itemsize */ |
208 | NULL, /* tp_dealloc */ |
209 | 0, /* tp_vectorcall_offset */ |
210 | NULL, /* tp_getattr */ |
211 | NULL, /* tp_setattr */ |
212 | NULL, /* tp_reserved */ |
213 | NULL, /* tp_repr */ |
214 | NULL, /* tp_as_number */ |
215 | NULL, /* tp_as_sequence */ |
216 | NULL, /* tp_as_mapping */ |
217 | NULL, /* tp_hash */ |
218 | NULL, /* tp_call */ |
219 | NULL, /* tp_str */ |
220 | NULL, /* tp_getattro */ |
221 | NULL, /* tp_setattro */ |
222 | NULL, /* tp_as_buffer */ |
223 | Py_TPFLAGS_DEFAULT, /* tp_flags */ |
224 | NULL, /* tp_doc */ |
225 | NULL, /* tp_traverse */ |
226 | NULL, /* tp_clear */ |
227 | NULL, /* tp_richcompare */ |
228 | 0, /* tp_weaklistoffset */ |
229 | NULL, /* tp_iter */ |
230 | NULL, /* tp_iternext */ |
231 | NULL, /* tp_methods */ |
232 | NULL, /* tp_members */ |
233 | THPDevice_properties, /* tp_getset */ |
234 | NULL, /* tp_base */ |
235 | NULL, /* tp_dict */ |
236 | NULL, /* tp_descr_get */ |
237 | NULL, /* tp_descr_set */ |
238 | 0, /* tp_dictoffset */ |
239 | NULL, /* tp_init */ |
240 | NULL, /* tp_alloc */ |
241 | NULL, /* tp_new */ |
242 | }; |
243 | |
244 | |
245 | THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) { |
246 | PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType; |
247 | THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0); |
248 | if (!self) |
249 | return NULL; |
250 | self->frame = frame; |
251 | return self; |
252 | } |
253 | |
254 | |
255 | #else |
256 | #define THP_EVAL_API_FRAME_OBJECT PyFrameObject |
257 | |
258 | #define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError |
259 | #endif |
260 | |
// Branch-prediction hint meaning "x is expected to be false". It uses
// __builtin_expect only where the compiler provides it (GCC/Clang, including
// MinGW and clang-cl) and falls back to the plain expression elsewhere.
#if defined(__GNUC__) || defined(__clang__)
#define unlikely(x) __builtin_expect((x), 0)
#else
#define unlikely(x) (x)
#endif

// Always-on check: if `val` is NULL, prints the location and any pending
// Python error, then aborts the process.
// Both macros are wrapped in do/while(0), so each expands to exactly one
// statement and can be used as the body of an unbraced `if` that has an
// `else`.
#define NULL_CHECK(val)                                            \
  do {                                                             \
    if (unlikely((val) == NULL)) {                                 \
      fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
      PyErr_Print();                                               \
      abort();                                                     \
    }                                                              \
  } while (0)

// Always-on check: if `cond` is false, prints the location and aborts.
#define CHECK(cond)                                                         \
  do {                                                                      \
    if (unlikely(!(cond))) {                                                \
      fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
      abort();                                                              \
    }                                                                       \
  } while (0)
281 | |
#ifndef TORCHDYNAMO_DEBUG

// Release build: every debug helper is a no-op, and its arguments are not
// evaluated.
#define DEBUG_CHECK(cond) ((void)0)
#define DEBUG_NULL_CHECK(val) ((void)0)
#define DEBUG_TRACE(msg, ...) ((void)0)
#define DEBUG_TRACE0(msg) ((void)0)

#else

// Debug build (TORCHDYNAMO_DEBUG): checks abort on failure, and traces go to
// stderr prefixed with the function name and line. DEBUG_TRACE needs at least
// one format argument; use DEBUG_TRACE0 for a bare message.
#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)

#endif
299 | |
300 | // Flag to just run a frame normally |
301 | #define SKIP_CODE ((void*)0x1) |
302 | |
303 | static PyObject* noargs = NULL; /* cached empty tuple */ |
304 | static PyObject* dotzerokey = NULL; /* ".0" */ |
305 | static PyObject* guard_fail_hook = NULL; |
306 | static PyObject* guard_error_hook = NULL; |
307 | |
308 | size_t = -1; |
309 | |
310 | static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT; |
311 | |
312 | inline static PyObject* eval_frame_callback_get(void) { |
313 | void* result = PyThread_tss_get(&eval_frame_callback_key); |
314 | if (unlikely(result == NULL)) { |
315 | Py_RETURN_NONE; |
316 | } else { |
317 | return (PyObject*)result; |
318 | } |
319 | } |
320 | |
321 | inline static void eval_frame_callback_set(PyObject* obj) { |
322 | PyThread_tss_set(&eval_frame_callback_key, obj); |
323 | } |
324 | |
325 | static void ignored(void* obj) {} |
326 | static PyObject* _custom_eval_frame_shim( |
327 | PyThreadState* tstate, |
328 | THP_EVAL_API_FRAME_OBJECT* frame, |
329 | int throw_flag); |
330 | static PyObject* _custom_eval_frame( |
331 | PyThreadState* tstate, |
332 | THP_EVAL_API_FRAME_OBJECT* frame, |
333 | int throw_flag, |
334 | PyObject* callback); |
335 | #if PY_VERSION_HEX >= 0x03090000 |
336 | static PyObject* custom_eval_frame_shim( |
337 | PyThreadState* tstate, |
338 | THP_EVAL_API_FRAME_OBJECT* frame, |
339 | int throw_flag) { |
340 | return _custom_eval_frame_shim(tstate, frame, throw_flag); |
341 | } |
342 | #else |
343 | static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) { |
344 | PyThreadState* tstate = PyThreadState_GET(); |
345 | return _custom_eval_frame_shim(tstate, frame, throw_flag); |
346 | } |
347 | #endif |
348 | |
349 | inline static PyObject* eval_frame_default( |
350 | PyThreadState* tstate, |
351 | THP_EVAL_API_FRAME_OBJECT* frame, |
352 | int throw_flag) { |
353 | #if PY_VERSION_HEX >= 0x03090000 |
354 | if (tstate == NULL) { |
355 | tstate = PyThreadState_GET(); |
356 | } |
357 | return _PyEval_EvalFrameDefault(tstate, frame, throw_flag); |
358 | #else |
359 | return _PyEval_EvalFrameDefault(frame, throw_flag); |
360 | #endif |
361 | } |
362 | |
363 | inline static void enable_eval_frame_shim(PyThreadState* tstate) { |
364 | #if PY_VERSION_HEX >= 0x03090000 |
365 | if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != |
366 | &custom_eval_frame_shim) { |
367 | _PyInterpreterState_SetEvalFrameFunc( |
368 | tstate->interp, &custom_eval_frame_shim); |
369 | } |
370 | #else |
371 | if (tstate->interp->eval_frame != &custom_eval_frame_shim) { |
372 | // First call |
373 | tstate->interp->eval_frame = &custom_eval_frame_shim; |
374 | } |
375 | #endif |
376 | } |
377 | |
378 | inline static void enable_eval_frame_default(PyThreadState* tstate) { |
379 | #if PY_VERSION_HEX >= 0x03090000 |
380 | if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) != |
381 | &_PyEval_EvalFrameDefault) { |
382 | _PyInterpreterState_SetEvalFrameFunc( |
383 | tstate->interp, &_PyEval_EvalFrameDefault); |
384 | } |
385 | #else |
386 | if (tstate->interp->eval_frame != &_PyEval_EvalFrameDefault) { |
387 | // First call |
388 | tstate->interp->eval_frame = &_PyEval_EvalFrameDefault; |
389 | } |
390 | #endif |
391 | } |
392 | |
393 | static inline PyObject* call_callback( |
394 | PyObject* callable, |
395 | THP_EVAL_API_FRAME_OBJECT* _frame, |
396 | long cache_len) { |
397 | |
398 | #if IS_PYTHON_3_11_PLUS |
399 | THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame); |
400 | #else |
401 | PyFrameObject* frame = _frame; |
402 | #endif |
403 | PyObject* args = Py_BuildValue("(Ol)" , frame, cache_len); |
404 | if (args == NULL) { |
405 | return NULL; |
406 | } |
407 | PyObject* result = PyObject_CallObject(callable, args); |
408 | Py_DECREF(args); |
409 | return result; |
410 | } |
411 | |
412 | typedef struct cache_entry { |
413 | // check the guards: lambda: <locals of user function>: bool |
414 | PyObject* check_fn; |
415 | // modified user bytecode (protected by check_fn's guards) |
416 | PyCodeObject* code; |
417 | // on a cache miss, linked list of next thing to try |
418 | struct cache_entry* next; |
419 | } CacheEntry; |
420 | |
421 | static CacheEntry* create_cache_entry( |
422 | CacheEntry* next, |
423 | PyObject* guarded_code) { |
424 | CacheEntry* e = (CacheEntry*)malloc(sizeof(CacheEntry)); |
425 | DEBUG_NULL_CHECK(e); |
426 | e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn" ); |
427 | NULL_CHECK(e->check_fn); |
428 | e->code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code" ); |
429 | NULL_CHECK(e->code); |
430 | e->next = next; |
431 | return e; |
432 | } |
433 | |
434 | static void destroy_cache_entry(CacheEntry* e) { |
435 | if (e == NULL || e == SKIP_CODE) { |
436 | return; |
437 | } |
438 | Py_XDECREF(e->check_fn); |
439 | Py_XDECREF(e->code); |
440 | destroy_cache_entry(e->next); |
441 | free(e); |
442 | } |
443 | |
444 | inline static CacheEntry* (PyCodeObject* code) { |
445 | CacheEntry* = NULL; |
446 | _PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra); |
447 | return extra; |
448 | } |
449 | |
450 | inline static void (PyCodeObject* code, CacheEntry* ) { |
451 | // TODO(jansel): would it be faster to bypass this? |
452 | _PyCode_SetExtra((PyObject*)code, extra_index, extra); |
453 | } |
454 | |
#ifdef TORCHDYNAMO_DEBUG
// Debug-only helper: UTF-8 name of the frame's code object, used in traces.
inline static const char* name(THP_EVAL_API_FRAME_OBJECT* frame) {
  PyObject* co_name = frame->f_code->co_name;
  DEBUG_CHECK(PyUnicode_Check(co_name));
  return PyUnicode_AsUTF8(co_name);
}
#endif
461 | |
462 | static PyObject* call_guard_fail_hook( |
463 | PyObject* hook, |
464 | CacheEntry* e, |
465 | PyObject* f_locals) { |
466 | // call debugging logic when a guard fails |
467 | PyObject* args = PyTuple_Pack( |
468 | 4, |
469 | e->check_fn, |
470 | e->code, |
471 | f_locals, |
472 | (e->next == NULL ? Py_True : Py_False)); |
473 | if (args == NULL) return NULL; |
474 | PyObject* result = PyObject_CallObject(hook, args); |
475 | Py_DECREF(args); |
476 | return result; |
477 | } |
478 | |
479 | // Return value: borrowed reference |
480 | // Is either Py_None or a PyCodeObject |
481 | static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev) { |
482 | if (e == NULL) { |
483 | // NB: intentionally not using Py_RETURN_NONE, to return borrowed ref |
484 | return Py_None; |
485 | } |
486 | PyObject *f_locals = frame->f_locals; |
487 | PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey); |
488 | PyObject* valid = NULL; |
489 | if (unlikely(dotzero != NULL)) { |
490 | // .0 is a special variable name used for implicit args |
491 | PyObject* args = PyTuple_Pack(1, dotzero); |
492 | if (args == NULL) return NULL; |
493 | valid = PyObject_Call(e->check_fn, args, f_locals); |
494 | Py_DECREF(args); |
495 | } else { |
496 | valid = PyObject_Call(e->check_fn, noargs, f_locals); |
497 | } |
498 | if (unlikely(valid == NULL)) { |
499 | if (guard_error_hook != NULL) { |
500 | PyObject *type, *value, *traceback; |
501 | PyErr_Fetch(&type, &value, &traceback); |
502 | PyObject* r = call_guard_fail_hook(guard_error_hook, e, f_locals); |
503 | if (r == NULL) { |
504 | return NULL; |
505 | } |
506 | Py_DECREF(r); |
507 | PyErr_Restore(type, value, traceback); |
508 | } |
509 | return NULL; |
510 | } |
511 | Py_DECREF(valid); |
512 | if (valid == Py_True) { |
513 | // Keep the head as the most recently used cache entry. |
514 | // If the hit cache entry is not the head of the linked list, |
515 | // move it to the head |
516 | if (prev != NULL) { |
517 | CacheEntry* = get_extra(frame->f_code); |
518 | prev->next = e->next; |
519 | e->next = extra; |
520 | set_extra(frame->f_code, e); |
521 | } |
522 | return (PyObject*)e->code; |
523 | } |
524 | if (unlikely(guard_fail_hook != NULL)) { |
525 | PyObject* r = call_guard_fail_hook(guard_fail_hook, e, f_locals); |
526 | if (r == NULL) { |
527 | return NULL; |
528 | } |
529 | Py_DECREF(r); |
530 | } |
531 | return lookup(e->next, frame, e); |
532 | } |
533 | |
534 | static long cache_size(CacheEntry* e) { |
535 | if (e == NULL) { |
536 | return 0; |
537 | } |
538 | return 1 + cache_size(e->next); |
539 | } |
540 | |
541 | inline static PyObject* eval_custom_code( |
542 | PyThreadState* tstate, |
543 | THP_EVAL_API_FRAME_OBJECT* frame, |
544 | PyCodeObject* code, |
545 | int throw_flag) { |
546 | Py_ssize_t ncells = 0; |
547 | Py_ssize_t nfrees = 0; |
548 | Py_ssize_t nlocals_new = code->co_nlocals; |
549 | Py_ssize_t nlocals_old = frame->f_code->co_nlocals; |
550 | |
551 | ncells = PyCode_GetNCellvars(code); |
552 | nfrees = PyCode_GetNFreevars(code); |
553 | |
554 | DEBUG_NULL_CHECK(tstate); |
555 | DEBUG_NULL_CHECK(frame); |
556 | DEBUG_NULL_CHECK(code); |
557 | DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars)); |
558 | DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars)); |
559 | DEBUG_CHECK(nlocals_new >= nlocals_old); |
560 | |
561 | PyFrameObject* shadow_obj = PyFrame_New(tstate, code, frame->f_globals, NULL); |
562 | #if IS_PYTHON_3_11_PLUS |
563 | THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj->f_frame; |
564 | #else |
565 | THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj; |
566 | #endif |
567 | if (shadow == NULL) { |
568 | return NULL; |
569 | } |
570 | |
571 | #if IS_PYTHON_3_11_PLUS |
572 | PyObject** fastlocals_old = frame->localsplus; |
573 | PyObject** fastlocals_new = shadow->localsplus; |
574 | #else |
575 | PyObject** fastlocals_old = frame->f_localsplus; |
576 | PyObject** fastlocals_new = shadow->f_localsplus; |
577 | #endif |
578 | |
579 | for (Py_ssize_t i = 0; i < nlocals_old; i++) { |
580 | Py_XINCREF(fastlocals_old[i]); |
581 | fastlocals_new[i] = fastlocals_old[i]; |
582 | } |
583 | |
584 | for (Py_ssize_t i = 0; i < ncells + nfrees; i++) { |
585 | Py_XINCREF(fastlocals_old[nlocals_old + i]); |
586 | fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i]; |
587 | } |
588 | |
589 | PyObject* result = eval_frame_default(tstate, shadow, throw_flag); |
590 | Py_DECREF(shadow); |
591 | return result; |
592 | } |
593 | |
594 | static PyObject* _custom_eval_frame_shim( |
595 | PyThreadState* tstate, |
596 | THP_EVAL_API_FRAME_OBJECT* frame, |
597 | int throw_flag) { |
598 | // Shims logic into one of three states. Can probably be refactored into a |
599 | // single func, later: |
600 | // - None: disables TorchDynamo |
601 | // - False: run-only mode (reuse existing compiles) |
602 | // - Python callable(): enables TorchDynamo |
603 | PyObject* callback = eval_frame_callback_get(); |
604 | |
605 | if (callback == Py_None) { |
606 | return eval_frame_default(tstate, frame, throw_flag); |
607 | } |
608 | |
609 | return _custom_eval_frame(tstate, frame, throw_flag, callback); |
610 | } |
611 | |
612 | static PyObject* _custom_eval_frame( |
613 | PyThreadState* tstate, |
614 | THP_EVAL_API_FRAME_OBJECT* frame, |
615 | int throw_flag, |
616 | PyObject* callback) { |
617 | DEBUG_TRACE( |
618 | "begin %s %s %i %i %i %i" , |
619 | name(frame), |
620 | PyUnicode_AsUTF8(frame->f_code->co_filename), |
621 | frame->f_lineno, |
622 | frame->f_lasti, |
623 | frame->f_iblock, |
624 | frame->f_executing); |
625 | CacheEntry* = get_extra(frame->f_code); |
626 | if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) { |
627 | DEBUG_TRACE("skip %s" , name(frame)); |
628 | return eval_frame_default(tstate, frame, throw_flag); |
629 | } |
630 | |
631 | // TODO(jansel): investigate directly using the "fast" representation |
632 | // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame |
633 | // even though we should pass a PyFrameObject. |
634 | if (THP_PyFrame_FastToLocalsWithError(frame) < 0) { |
635 | DEBUG_TRACE("error %s" , name(frame)); |
636 | return NULL; |
637 | } |
638 | |
639 | // A callback of Py_False indicates "run only" mode, the cache is checked, but |
640 | // we never compile. |
641 | if (callback == Py_False) { |
642 | DEBUG_TRACE("In run only mode %s" , name(frame)); |
643 | PyObject* maybe_cached_code = lookup(extra, frame, NULL); |
644 | if (maybe_cached_code == NULL) { |
645 | // guard eval failed, keep propagating |
646 | return NULL; |
647 | } else if (maybe_cached_code == Py_None) { |
648 | DEBUG_TRACE("cache miss %s" , name(frame)); |
649 | return eval_frame_default(tstate, frame, throw_flag); |
650 | } |
651 | PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code; |
652 | // used cached version |
653 | DEBUG_TRACE("cache hit %s" , name(frame)); |
654 | return eval_custom_code(tstate, frame, cached_code, throw_flag); |
655 | } |
656 | DEBUG_CHECK(PyDict_CheckExact(frame->f_locals)); |
657 | DEBUG_CHECK(PyDict_CheckExact(frame->f_globals)); |
658 | DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins)); |
659 | |
660 | // We don't run the current custom_eval_frame behavior for guards. |
661 | // So we temporarily set the callback to Py_None to drive the correct behavior |
662 | // in the shim. |
663 | eval_frame_callback_set(Py_None); |
664 | |
665 | PyObject* maybe_cached_code = lookup(extra, frame, NULL); |
666 | if (maybe_cached_code == NULL) { |
667 | // Python error |
668 | return NULL; |
669 | } else if (maybe_cached_code != Py_None) { |
670 | PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code; |
671 | // used cached version |
672 | DEBUG_TRACE("cache hit %s" , name(frame)); |
673 | // Re-enable custom behavior |
674 | eval_frame_callback_set(callback); |
675 | return eval_custom_code(tstate, frame, cached_code, throw_flag); |
676 | } |
677 | // cache miss |
678 | |
679 | // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame |
680 | // that gets re-interpreted as a PyObject (which it is NOT!) |
681 | PyObject* result = |
682 | call_callback(callback, frame, cache_size(extra)); |
683 | if (result == NULL) { |
684 | // internal exception, returning here will leak the exception into user code |
685 | // this is useful for debugging -- but we dont want it to happen outside of |
686 | // testing |
687 | return NULL; |
688 | } else if (result != Py_None) { |
689 | DEBUG_TRACE("create cache %s" , name(frame)); |
690 | extra = create_cache_entry(extra, result); |
691 | Py_DECREF(result); |
692 | set_extra(frame->f_code, extra); |
693 | // Re-enable custom behavior |
694 | eval_frame_callback_set(callback); |
695 | return eval_custom_code(tstate, frame, extra->code, throw_flag); |
696 | } else { |
697 | DEBUG_TRACE("create skip %s" , name(frame)); |
698 | Py_DECREF(result); |
699 | destroy_cache_entry(extra); |
700 | set_extra(frame->f_code, SKIP_CODE); |
701 | // Re-enable custom behavior |
702 | eval_frame_callback_set(callback); |
703 | return eval_frame_default(tstate, frame, throw_flag); |
704 | } |
705 | } |
706 | |
707 | static int active_dynamo_threads = 0; |
708 | |
709 | static PyObject* increment_working_threads(PyThreadState* tstate) { |
710 | active_dynamo_threads = active_dynamo_threads + 1; |
711 | if (active_dynamo_threads > 0) { |
712 | enable_eval_frame_shim(tstate); |
713 | } |
714 | Py_RETURN_NONE; |
715 | } |
716 | |
717 | static PyObject* decrement_working_threads(PyThreadState* tstate) { |
718 | if (active_dynamo_threads > 0) { |
719 | active_dynamo_threads = active_dynamo_threads - 1; |
720 | if (active_dynamo_threads == 0) { |
721 | enable_eval_frame_default(tstate); |
722 | } |
723 | } |
724 | Py_RETURN_NONE; |
725 | } |
726 | |
727 | static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) { |
728 | // Change the eval frame callback and return the old one |
729 | // - None: disables TorchDynamo |
730 | // - False: run-only mode (reuse existing compiles) |
731 | // - Python callable(): enables TorchDynamo |
732 | PyObject* old_callback = eval_frame_callback_get(); |
733 | |
734 | // owned by caller |
735 | Py_INCREF(old_callback); |
736 | |
737 | if (old_callback != Py_None && new_callback == Py_None) { |
738 | decrement_working_threads(tstate); |
739 | } else if (old_callback == Py_None && new_callback != Py_None) { |
740 | increment_working_threads(tstate); |
741 | } |
742 | |
743 | Py_INCREF(new_callback); |
744 | Py_DECREF(old_callback); |
745 | |
746 | // Set thread local callback. This will drive behavior of our shim, if/when it |
747 | // is installed. |
748 | eval_frame_callback_set(new_callback); |
749 | |
750 | return old_callback; |
751 | } |
752 | |
753 | static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* args) { |
754 | PyObject* callback = NULL; |
755 | if (!PyArg_ParseTuple(args, "O:callback" , &callback)) { |
756 | DEBUG_TRACE0("arg error" ); |
757 | return NULL; |
758 | } |
759 | if (callback != Py_None && callback != Py_False && |
760 | !PyCallable_Check(callback)) { |
761 | DEBUG_TRACE0("arg error" ); |
762 | PyErr_SetString(PyExc_TypeError, "expected a callable" ); |
763 | return NULL; |
764 | } |
765 | DEBUG_TRACE( |
766 | "python enabled=%d and is run_only=%d" , |
767 | callback != Py_None, |
768 | callback == Py_False); |
769 | return set_eval_frame(callback, PyThreadState_GET()); |
770 | } |
771 | |
772 | static PyObject* reset_code(PyObject* dummy, PyObject* args) { |
773 | PyObject* code = NULL; |
774 | if (!PyArg_ParseTuple(args, "O:code" , &code)) { |
775 | DEBUG_TRACE0("arg error" ); |
776 | return NULL; |
777 | } |
778 | if (!PyCode_Check(code)) { |
779 | DEBUG_TRACE0("arg error" ); |
780 | PyErr_SetString(PyExc_TypeError, "expected a code object" ); |
781 | return NULL; |
782 | } |
783 | |
784 | destroy_cache_entry(get_extra((PyCodeObject*)code)); |
785 | set_extra((PyCodeObject*)code, NULL); |
786 | Py_RETURN_NONE; |
787 | } |
788 | |
789 | static PyObject* unsupported(PyObject* dummy, PyObject* args) { |
790 | // a dummy C function used in testing |
791 | PyObject* obj1 = NULL; |
792 | PyObject* obj2 = NULL; |
793 | if (!PyArg_ParseTuple(args, "OO" , &obj1, &obj2)) { |
794 | return NULL; |
795 | } |
796 | Py_INCREF(obj2); |
797 | return obj2; |
798 | } |
799 | |
800 | static PyObject* skip_code(PyObject* dummy, PyObject* args) { |
801 | PyObject* obj = NULL; |
802 | if (!PyArg_ParseTuple(args, "O" , &obj)) { |
803 | return NULL; |
804 | } |
805 | if (!PyCode_Check(obj)) { |
806 | PyErr_SetString(PyExc_TypeError, "expected a code object" ); |
807 | return NULL; |
808 | } |
809 | set_extra((PyCodeObject*)obj, SKIP_CODE); |
810 | Py_RETURN_NONE; |
811 | } |
812 | |
813 | static PyObject* set_guard_fail_hook(PyObject* dummy, PyObject* args) { |
814 | PyObject* obj = NULL; |
815 | if (!PyArg_ParseTuple(args, "O" , &obj)) { |
816 | return NULL; |
817 | } |
818 | Py_XDECREF(guard_fail_hook); |
819 | if (obj == Py_None) { |
820 | guard_fail_hook = NULL; |
821 | } else { |
822 | guard_fail_hook = obj; |
823 | Py_INCREF(guard_fail_hook); |
824 | } |
825 | Py_RETURN_NONE; |
826 | } |
827 | |
828 | static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* args) { |
829 | PyObject* obj = NULL; |
830 | if (!PyArg_ParseTuple(args, "O" , &obj)) { |
831 | return NULL; |
832 | } |
833 | Py_XDECREF(guard_error_hook); |
834 | if (obj == Py_None) { |
835 | guard_error_hook = NULL; |
836 | } else { |
837 | guard_error_hook = obj; |
838 | Py_INCREF(guard_error_hook); |
839 | } |
840 | Py_RETURN_NONE; |
841 | } |
842 | |
843 | static PyMethodDef _methods[] = { |
844 | {"set_eval_frame" , set_eval_frame_py, METH_VARARGS, NULL}, |
845 | {"reset_code" , reset_code, METH_VARARGS, NULL}, |
846 | {"unsupported" , unsupported, METH_VARARGS, NULL}, |
847 | {"skip_code" , skip_code, METH_VARARGS, NULL}, |
848 | {"set_guard_fail_hook" , set_guard_fail_hook, METH_VARARGS, NULL}, |
849 | {"set_guard_error_hook" , set_guard_error_hook, METH_VARARGS, NULL}, |
850 | {NULL, NULL, 0, NULL}}; |
851 | |
852 | static struct PyModuleDef _module = { |
853 | PyModuleDef_HEAD_INIT, |
854 | "torch._C._dynamo.eval_frame" , |
855 | "Module containing hooks to override eval_frame" , |
856 | -1, |
857 | _methods}; |
858 | |
859 | PyObject* torch_c_dynamo_eval_frame_init(void) { |
860 | extra_index = _PyEval_RequestCodeExtraIndex(ignored); |
861 | |
862 | int result = PyThread_tss_create(&eval_frame_callback_key); |
863 | CHECK(result == 0); |
864 | |
865 | Py_INCREF(Py_None); |
866 | eval_frame_callback_set(Py_None); |
867 | |
868 | noargs = PyTuple_New(0); |
869 | dotzerokey = PyUnicode_InternFromString(".0" ); |
870 | PyObject* module = PyModule_Create(&_module); |
871 | |
872 | #if IS_PYTHON_3_11_PLUS |
873 | if (PyType_Ready(&THPPyInterpreterFrameType) < 0) { |
874 | return NULL; |
875 | } |
876 | Py_INCREF(&THPPyInterpreterFrameType); |
877 | if (PyModule_AddObject(module, "_PyInterpreterFrame" , (PyObject*)&THPPyInterpreterFrameType) != 0) { |
878 | return NULL; |
879 | } |
880 | #endif |
881 | |
882 | return module; |
883 | } |
884 | |