1 | #pragma once |
2 | |
3 | #include <torch/csrc/python_headers.h> |
4 | |
5 | #include <torch/csrc/autograd/engine.h> |
6 | #include <torch/csrc/autograd/function.h> |
7 | |
8 | bool THPEngine_initModule(PyObject* module); |
9 | |
10 | namespace torch { |
11 | namespace autograd { |
12 | namespace python { |
13 | |
14 | struct PythonEngine : public Engine { |
15 | static Engine& get_python_engine(); |
16 | ~PythonEngine() override; |
17 | void thread_init( |
18 | int device, |
19 | const std::shared_ptr<ReadyQueue>& ready_queue, |
20 | bool should_increment) override; |
21 | void thread_on_exception( |
22 | std::shared_ptr<GraphTask> graph_task, |
23 | const std::shared_ptr<Node>& fn, |
24 | std::exception& e) override; |
25 | variable_list execute( |
26 | const edge_list& roots, |
27 | const variable_list& inputs, |
28 | bool keep_graph, |
29 | bool create_graph, |
30 | bool accumulate_grad, |
31 | const edge_list& outputs = {}) override; |
32 | |
33 | c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task( |
34 | const std::shared_ptr<GraphTask>& graph_task, |
35 | std::shared_ptr<Node> graph_root, |
36 | InputBuffer&& input_buffer) override; |
37 | |
38 | std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override; |
39 | std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() |
40 | override; |
41 | |
42 | private: |
43 | PythonEngine(); |
44 | }; |
45 | |
46 | } // namespace python |
47 | } // namespace autograd |
48 | } // namespace torch |
49 | |