1 | #include <torch/csrc/autograd/python_engine.h> |
2 | |
3 | #include <ATen/LegacyBatchedTensorImpl.h> |
4 | #include <ATen/LegacyVmapMode.h> |
5 | #include <c10/util/irange.h> |
6 | #include <pybind11/pybind11.h> |
7 | #include <torch/csrc/DynamicTypes.h> |
8 | #include <torch/csrc/THP.h> |
9 | #include <torch/csrc/autograd/edge.h> |
10 | #include <torch/csrc/autograd/engine.h> |
11 | #include <torch/csrc/autograd/function.h> |
12 | #include <torch/csrc/autograd/functions/basic_ops.h> |
13 | #include <torch/csrc/autograd/python_anomaly_mode.h> |
14 | #include <torch/csrc/autograd/python_function.h> |
15 | #include <torch/csrc/autograd/python_saved_variable_hooks.h> |
16 | #include <torch/csrc/utils/pybind.h> |
17 | #include <torch/csrc/utils/pycfunction_helpers.h> |
18 | |
19 | #ifndef _WIN32 |
20 | #include <pthread.h> |
21 | #endif |
22 | |
23 | #include <memory> // for unique_ptr |
24 | #include <unordered_set> |
25 | #include <utility> |
26 | |
27 | using namespace torch::autograd; |
28 | |
29 | struct THPEngine { |
30 | PyObject_HEAD |
31 | }; |
32 | |
33 | static bool _reinitialize_engine = false; |
34 | |
35 | namespace torch { |
36 | namespace autograd { |
37 | namespace python { |
38 | |
39 | PythonEngine::PythonEngine() = default; |
40 | |
41 | Engine& PythonEngine::get_python_engine() { |
42 | static PythonEngine engine; |
43 | // This is "probably" thread-safe because the flag is set in a fork handler |
44 | // before any threads are created, and this function is only called with the |
45 | // GIL held. However, using fork + threads is playing with fire so this is |
46 | // more of a "best effort" thing. For example, if the fork occurs while the |
47 | // backwards threads hold a lock, we'll probably deadlock in the engine |
48 | // destructor. |
49 | if (_reinitialize_engine) { |
50 | engine.release_workers(); |
51 | engine.~PythonEngine(); |
52 | new (&engine) torch::autograd::python::PythonEngine(); |
53 | _reinitialize_engine = false; |
54 | } |
55 | return engine; |
56 | } |
57 | |
58 | PythonEngine::~PythonEngine() { |
59 | Engine::stop(); |
60 | } |
61 | |
62 | #if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9 |
63 | #define IS_PYTHON_3_9_PLUS |
64 | #endif |
65 | |
66 | void PythonEngine::thread_init( |
67 | int device, |
68 | const std::shared_ptr<ReadyQueue>& ready_queue, |
69 | bool should_increment) { |
70 | // Increment thread usage count before acquiring the GIL |
71 | if (should_increment) { |
72 | increment_non_reentrant_thread_count(); |
73 | } |
74 | // Create a PyThreadState, but release the GIL. This lets |
75 | // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL |
76 | // without having to create a new PyThreadState each time. |
77 | #if defined(IS_PYTHON_3_9_PLUS) |
78 | auto gil = std::make_unique<pybind11::gil_scoped_acquire>(); |
79 | #else |
80 | pybind11::gil_scoped_acquire gil; |
81 | #endif |
82 | pybind11::gil_scoped_release no_gil; |
83 | Engine::thread_init(device, ready_queue, false); |
84 | |
85 | if (should_increment) { |
86 | // Decrement the count during shutdown if we incremented earlier. |
87 | decrement_non_reentrant_thread_count(); |
88 | } |
89 | |
90 | #if defined(IS_PYTHON_3_9_PLUS) |
91 | // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if |
92 | // runtime is finalizing |
93 | if (!Py_IsInitialized()) { |
94 | no_gil.disarm(); |
    // TODO: call disarm() rather than leak the gil_scoped_acquire once
    // PyThreadState_Clear can safely be called from finalize.
    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid a use-after-free here.
99 | gil.release(); |
100 | } |
101 | #endif |
102 | } |
103 | |
104 | void PythonEngine::thread_on_exception( |
105 | std::shared_ptr<GraphTask> graph_task, |
106 | const std::shared_ptr<Node>& fn, |
107 | std::exception& e) { |
108 | auto python_err = dynamic_cast<python_error*>(&e); |
109 | if (python_err) { |
110 | python_err->persist(); |
111 | } |
112 | Engine::thread_on_exception(std::move(graph_task), fn, e); |
113 | } |
114 | |
115 | std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() { |
  return std::make_unique<PyAnomalyMetadata>();
117 | } |
118 | |
119 | std::unique_ptr<SavedVariableHooks> PythonEngine:: |
120 | get_default_saved_variable_hooks() { |
121 | return PyDefaultSavedVariableHooks::get_hooks(); |
122 | } |
123 | |
124 | variable_list PythonEngine::execute( |
125 | const edge_list& roots, |
126 | const variable_list& inputs, |
127 | bool keep_graph, |
128 | bool create_graph, |
129 | bool accumulate_grad, |
130 | const edge_list& outputs) { |
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.");
137 | try { |
138 | return Engine::execute( |
139 | roots, inputs, keep_graph, create_graph, accumulate_grad, outputs); |
140 | } catch (python_error& e) { |
141 | e.restore(); |
142 | throw; |
143 | } |
144 | } |
145 | |
146 | c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task( |
147 | const std::shared_ptr<GraphTask>& graph_task, |
148 | std::shared_ptr<Node> graph_root, |
149 | InputBuffer&& input_buffer) { |
150 | try { |
151 | return Engine::execute_with_graph_task( |
152 | graph_task, std::move(graph_root), std::move(input_buffer)); |
153 | } catch (python_error& e) { |
154 | pybind11::gil_scoped_acquire gil; |
155 | if (!PyErr_Occurred()) { |
156 | // Set the error indicator only if it is not set already. |
157 | e.restore(); |
158 | } |
159 | throw; |
160 | } |
161 | } |
162 | } // namespace python |
163 | } // namespace autograd |
164 | } // namespace torch |
165 | |
166 | PyObject* THPEngineClass = nullptr; |
167 | |
168 | // Implementation of torch._C._EngineBase.run_backward |
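// Roughly how the Python side reaches this (a sketch based on
// torch/autograd/__init__.py; the exact call may differ across versions):
//
//   # torch.autograd.backward -> accumulate_grad=True, grads go into .grad
//   # torch.autograd.grad     -> accumulate_grad=False, grads are returned
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, keep_graph, create_graph, inputs,
//       allow_unreachable, accumulate_grad)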
169 | PyObject* THPEngine_run_backward( |
170 | PyObject* self, |
171 | PyObject* args, |
172 | PyObject* kwargs) { |
173 | HANDLE_TH_ERRORS |
174 | PyObject* tensors = nullptr; |
175 | PyObject* grad_tensors = nullptr; |
176 | unsigned char keep_graph = 0; |
177 | unsigned char create_graph = 0; |
178 | PyObject* inputs = nullptr; |
179 | unsigned char allow_unreachable = 0; |
  unsigned char accumulate_grad = 0;
  // Indicates whether to accumulate grad into leaf Tensors (backward()) or to
  // capture and return the gradients of the given inputs (grad()).
  const char* accepted_kwargs[] = {// NOLINT
      "tensors",
      "grad_tensors",
      "keep_graph",
      "create_graph",
      "inputs",
      "allow_unreachable",
      "accumulate_grad",
      nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          const_cast<char**>(accepted_kwargs),
196 | &tensors, |
197 | &grad_tensors, |
198 | &keep_graph, |
199 | &create_graph, |
200 | &inputs, |
201 | &allow_unreachable, |
202 | &accumulate_grad)) |
203 | return nullptr; |
  THPUtils_assert(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got %s",
      THPUtils_typename(tensors));
  THPUtils_assert(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got %s",
      THPUtils_typename(grad_tensors));
214 | |
215 | Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors); |
216 | Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors); |
  THPUtils_assert(
      num_tensors == num_gradients,
      "got %ld tensors and %ld "
      "gradients",
      num_tensors,
      num_gradients);
223 | |
224 | // The user either called autograd.backward(...) or autograd.grad(...) to get |
225 | // here |
226 | bool backward_api_called = accumulate_grad; |
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");
232 | |
233 | edge_list roots; |
234 | roots.reserve(num_tensors); |
235 | variable_list grads; |
236 | grads.reserve(num_tensors); |
237 | for (const auto i : c10::irange(num_tensors)) { |
238 | PyObject* _tensor = PyTuple_GET_ITEM(tensors, i); |
    THPUtils_assert(
        THPVariable_Check(_tensor),
        "element %d of tensors "
        "tuple is not a Tensor",
        i);
244 | const auto& variable = THPVariable_Unpack(_tensor); |
    TORCH_CHECK(
        !isBatchedTensor(variable),
        "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
        "torch.vmap. We do not support the case where any outputs are ",
        "vmapped tensors (output ",
        i,
        " is being vmapped over). Please "
        "call autograd.grad() outside torch.vmap or file a bug report "
        "with your use case.");
254 | auto gradient_edge = torch::autograd::impl::gradient_edge(variable); |
    THPUtils_assert(
        gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn",
        i);
259 | roots.push_back(std::move(gradient_edge)); |
260 | |
261 | PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i); |
262 | if (THPVariable_Check(grad)) { |
263 | const Variable& grad_var = THPVariable_Unpack(grad); |
264 | if (grad_var.has_names()) { |
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
271 | } |
272 | grads.push_back(grad_var); |
273 | } else { |
      THPUtils_assert(
          grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None",
          i);
      THPUtils_assert(
          !variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad",
          i);
281 | } |
282 | } |
283 | |
284 | std::vector<Edge> output_edges; |
285 | if (inputs != nullptr) { |
286 | int num_inputs = PyTuple_GET_SIZE(inputs); |
287 | output_edges.reserve(num_inputs); |
288 | for (const auto i : c10::irange(num_inputs)) { |
289 | PyObject* input = PyTuple_GET_ITEM(inputs, i); |
      THPUtils_assert(
          THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s",
          THPUtils_typename(input));
294 | const auto& tensor = THPVariable_Unpack(input); |
      TORCH_CHECK(
          !isBatchedTensor(tensor),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any inputs are ",
          "vmapped tensors (input ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
304 | const auto output_nr = tensor.output_nr(); |
305 | auto grad_fn = tensor.grad_fn(); |
306 | if (!grad_fn) { |
307 | grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor); |
308 | } |
309 | if (accumulate_grad) { |
310 | tensor.retain_grad(); |
311 | } |
      THPUtils_assert(
          tensor.requires_grad(),
          "One of the differentiated Tensors does not require grad");
315 | if (!grad_fn) { |
316 | // NOTE [ Autograd Unreachable Input ] |
317 | // Since input has no grad_accumulator, its guaranteed to be |
318 | // unreachable. We initialize an edge pointing to a non-nullptr Node so |
319 | // nodes in the graph (e.g., mul when an operand is scalar) that have |
320 | // edges pointing to nullptr don't get erroneously assigned `needed = |
321 | // True` in exec_info. |
322 | output_edges.emplace_back(std::make_shared<Identity>(), 0); |
323 | } else { |
324 | output_edges.emplace_back(grad_fn, output_nr); |
325 | } |
326 | } |
327 | } |
328 | |
329 | variable_list outputs; |
330 | { |
331 | pybind11::gil_scoped_release no_gil; |
332 | auto& engine = python::PythonEngine::get_python_engine(); |
333 | outputs = engine.execute( |
334 | roots, grads, keep_graph, create_graph, accumulate_grad, output_edges); |
335 | } |
336 | |
337 | if (!backward_api_called && inputs != nullptr) { |
338 | int num_inputs = PyTuple_GET_SIZE(inputs); |
339 | THPObjectPtr py_outputs{PyTuple_New(num_inputs)}; |
340 | if (!py_outputs) |
341 | return nullptr; |
342 | for (const auto i : c10::irange(num_inputs)) { |
      THPUtils_assert(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
349 | PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i])); |
350 | } |
351 | return py_outputs.release(); |
352 | } else { |
353 | Py_RETURN_NONE; |
354 | } |
355 | END_HANDLE_TH_ERRORS |
356 | } |
357 | |
358 | PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) { |
359 | HANDLE_TH_ERRORS |
360 | auto& engine = python::PythonEngine::get_python_engine(); |
361 | std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) { |
362 | pybind11::gil_scoped_acquire gil; |
363 | Py_DECREF(obj); |
364 | }); |
365 | Py_INCREF(_callback); |
366 | engine.queue_callback([callback]() { |
367 | pybind11::gil_scoped_acquire gil; |
368 | THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)}; |
369 | if (!result) |
370 | throw python_error(); |
371 | }); |
372 | Py_RETURN_NONE; |
373 | END_HANDLE_TH_ERRORS |
374 | } |
375 | |
376 | PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) { |
377 | HANDLE_TH_ERRORS |
378 | auto& engine = python::PythonEngine::get_python_engine(); |
379 | if (engine.is_checkpoint_valid()) { |
380 | Py_RETURN_TRUE; |
381 | } else { |
382 | Py_RETURN_FALSE; |
383 | } |
384 | END_HANDLE_TH_ERRORS |
385 | } |
386 | |
387 | PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) { |
388 | return type->tp_alloc(type, 0); |
389 | } |
390 | |
391 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
392 | static struct PyMethodDef THPEngine_methods[] = { |
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
399 | THPEngine_is_checkpoint_valid, |
400 | METH_NOARGS, |
401 | nullptr}, |
402 | {nullptr}}; |
403 | |
404 | PyTypeObject THPEngineType = { |
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
406 | sizeof(THPEngine), /* tp_basicsize */ |
407 | 0, /* tp_itemsize */ |
408 | nullptr, /* tp_dealloc */ |
409 | 0, /* tp_vectorcall_offset */ |
410 | nullptr, /* tp_getattr */ |
411 | nullptr, /* tp_setattr */ |
412 | nullptr, /* tp_reserved */ |
413 | nullptr, /* tp_repr */ |
414 | nullptr, /* tp_as_number */ |
415 | nullptr, /* tp_as_sequence */ |
416 | nullptr, /* tp_as_mapping */ |
417 | nullptr, /* tp_hash */ |
418 | nullptr, /* tp_call */ |
419 | nullptr, /* tp_str */ |
420 | nullptr, /* tp_getattro */ |
421 | nullptr, /* tp_setattro */ |
422 | nullptr, /* tp_as_buffer */ |
423 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ |
424 | nullptr, /* tp_doc */ |
425 | nullptr, /* tp_traverse */ |
426 | nullptr, /* tp_clear */ |
427 | nullptr, /* tp_richcompare */ |
428 | 0, /* tp_weaklistoffset */ |
429 | nullptr, /* tp_iter */ |
430 | nullptr, /* tp_iternext */ |
431 | THPEngine_methods, /* tp_methods */ |
432 | nullptr, /* tp_members */ |
433 | nullptr, /* tp_getset */ |
434 | nullptr, /* tp_base */ |
435 | nullptr, /* tp_dict */ |
436 | nullptr, /* tp_descr_get */ |
437 | nullptr, /* tp_descr_set */ |
438 | 0, /* tp_dictoffset */ |
439 | nullptr, /* tp_init */ |
440 | nullptr, /* tp_alloc */ |
441 | THPEngine_new /* tp_new */ |
442 | }; |
443 | |
444 | static void child_atfork() { |
445 | _reinitialize_engine = true; |
446 | } |
447 | |
448 | bool THPEngine_initModule(PyObject* module) { |
449 | #ifndef _WIN32 |
450 | if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) { |
    throw std::runtime_error("unable to set pthread_atfork handler");
452 | } |
453 | #endif |
454 | if (PyType_Ready(&THPEngineType) < 0) |
455 | return false; |
456 | Py_INCREF(&THPEngineType); |
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
458 | set_default_engine_stub(python::PythonEngine::get_python_engine); |
459 | return true; |
460 | } |
461 | |