1 | #pragma once |
2 | |
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | struct Module; |
15 | |
16 | namespace tracer { |
17 | void initPythonTracerBindings(PyObject* module); |
18 | |
19 | SourceRange getPythonInterpreterSourceRange(); |
20 | |
21 | Node* preRecordPythonTrace( |
22 | THPObjectPtr pyobj, |
23 | const std::string& arg_types, |
24 | at::ArrayRef<autograd::Variable> inputs, |
25 | std::vector<THPObjectPtr> scalar_args); |
26 | |
27 | std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict( |
28 | const py::function& func, |
29 | const py::dict& inputs_dict, |
30 | Stack inputs, |
31 | const py::function& var_name_lookup_fn, |
32 | bool strict, |
33 | bool force_outplace, |
34 | Module* self = nullptr, |
35 | const std::vector<std::string>& argument_names = {}); |
36 | |
37 | std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing( |
38 | const py::function& func, |
39 | Stack inputs, |
40 | const py::function& var_name_lookup_fn, |
41 | bool strict, |
42 | bool force_outplace, |
43 | Module* self = nullptr, |
44 | const std::vector<std::string>& argument_names = {}); |
45 | } // namespace tracer |
46 | } // namespace jit |
47 | } // namespace torch |
48 | |