#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <unordered_set>
#include <utility>

using namespace torch::autograd;

struct THPEngine {
  PyObject_HEAD
};

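// Set from the pthread_atfork() child handler (see child_atfork() below):
// autograd worker threads do not survive fork(), so the engine is lazily
// reinitialized the next time get_python_engine() is called in the child.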
static bool _reinitialize_engine = false;

namespace torch {
namespace autograd {
namespace python {

PythonEngine::PythonEngine() = default;

Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif
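
// On Python 3.9+ the gil_scoped_acquire below is heap-allocated so that, when
// the interpreter is already finalizing at thread shutdown, we can deliberately
// leak it (via gil.release()) instead of touching a PyThreadState that may no
// longer be valid.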

void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread or PyThreadState_[Clear|DeleteCurrent] if
  // the runtime is finalizing.
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm rather than leak the gil_scoped_acquire once
    // PyThreadState_Clear can safely be called during finalization.
    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct the
    // PyThreadState, so avoid a use-after-free here.
    gil.release();
  }
#endif
}

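// If the exception raised on a worker thread originated in Python, persist the
// Python error state on the python_error object so it can be restored and
// re-raised later on the thread that is waiting on the graph task.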
void PythonEngine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(std::move(graph_task), fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.");
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}
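
// Usage sketch (not part of this file's API surface; `loss` is a hypothetical
// scalar Variable, and the exact C++ entry point may differ): callers that
// currently hold the GIL should release it before entering the engine, as the
// check above requires, e.g.
//
//   {
//     pybind11::gil_scoped_release no_gil;
//     torch::autograd::backward({loss});
//   }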

c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}
} // namespace python
} // namespace autograd
} // namespace torch

PyObject* THPEngineClass = nullptr;

// Implementation of torch._C._EngineBase.run_backward
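//
// Python-side call sketch (argument names and defaults may differ across
// versions): torch/autograd/__init__.py reaches this roughly as
//
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, retain_graph, create_graph, inputs,
//       allow_unreachable=True, accumulate_grad=True)
//
// with accumulate_grad=True for autograd.backward() and False for
// autograd.grad().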
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad =
      0; // Whether to accumulate gradients into leaf Tensors (backward()) or
         // capture and return them (grad())
  constexpr const char* accepted_kwargs[] = {// NOLINT
      "tensors",
      "grad_tensors",
      "keep_graph",
      "create_graph",
      "inputs",
      "allow_unreachable",
      "accumulate_grad",
      nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  THPUtils_assert(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got %s",
      THPUtils_typename(tensors));
  THPUtils_assert(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got %s",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  THPUtils_assert(
      num_tensors == num_gradients,
      "got %ld tensors and %ld "
      "gradients",
      num_tensors,
      num_gradients);

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(
        THPVariable_Check(_tensor),
        "element %d of tensors "
        "tuple is not a Tensor",
        i);
    const auto& variable = THPVariable_Unpack(_tensor);
    TORCH_CHECK(
        !isBatchedTensor(variable),
        "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
        "torch.vmap. We do not support the case where any outputs are ",
        "vmapped tensors (output ",
        i,
        " is being vmapped over). Please "
        "call autograd.grad() outside torch.vmap or file a bug report "
        "with your use case.");
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    THPUtils_assert(
        gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn",
        i);
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      THPUtils_assert(
          grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None",
          i);
      THPUtils_assert(
          !variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad",
          i);
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(
          THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s",
          THPUtils_typename(input));
      const auto& tensor = THPVariable_Unpack(input);
      TORCH_CHECK(
          !isBatchedTensor(tensor),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any inputs are ",
          "vmapped tensors (input ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      const auto output_nr = tensor.output_nr();
      auto grad_fn = tensor.grad_fn();
      if (!grad_fn) {
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
      }
      if (accumulate_grad) {
        tensor.retain_grad();
      }
      THPUtils_assert(
          tensor.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        // NOTE [ Autograd Unreachable Input ]
        // Since the input has no grad_accumulator, it's guaranteed to be
        // unreachable. We initialize an edge pointing to a non-nullptr Node so
        // nodes in the graph (e.g., mul when an operand is scalar) that have
        // edges pointing to nullptr don't get erroneously assigned `needed =
        // True` in exec_info.
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      THPUtils_assert(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result)
      throw python_error();
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
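
// Python-side usage sketch (exact call sites vary across versions): callbacks
// queued here run once the current backward pass finishes, e.g.
//
//   torch.autograd.Variable._execution_engine.queue_callback(
//       lambda: print("backward finished"))
//
// The GIL is re-acquired both when the callback is invoked and when its
// PyObject is finally released.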

PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

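// pthread_atfork() child handler: autograd worker threads do not survive
// fork(), so mark the engine for re-initialization; get_python_engine() will
// rebuild it lazily in the child process.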
static void child_atfork() {
  _reinitialize_engine = true;
}

bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}