1#pragma once
2
3#include <torch/csrc/python_headers.h>
4
5#include <torch/csrc/autograd/engine.h>
6#include <torch/csrc/autograd/function.h>
7
8bool THPEngine_initModule(PyObject* module);
9
10namespace torch {
11namespace autograd {
12namespace python {
13
14struct PythonEngine : public Engine {
15 static Engine& get_python_engine();
16 ~PythonEngine() override;
17 void thread_init(
18 int device,
19 const std::shared_ptr<ReadyQueue>& ready_queue,
20 bool should_increment) override;
21 void thread_on_exception(
22 std::shared_ptr<GraphTask> graph_task,
23 const std::shared_ptr<Node>& fn,
24 std::exception& e) override;
25 variable_list execute(
26 const edge_list& roots,
27 const variable_list& inputs,
28 bool keep_graph,
29 bool create_graph,
30 bool accumulate_grad,
31 const edge_list& outputs = {}) override;
32
33 c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
34 const std::shared_ptr<GraphTask>& graph_task,
35 std::shared_ptr<Node> graph_root,
36 InputBuffer&& input_buffer) override;
37
38 std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override;
39 std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks()
40 override;
41
42 private:
43 PythonEngine();
44};
45
46} // namespace python
47} // namespace autograd
48} // namespace torch
49