1#pragma once
2
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>
10
11namespace torch {
12namespace jit {
13
14struct Module;
15
16namespace tracer {
17void initPythonTracerBindings(PyObject* module);
18
19SourceRange getPythonInterpreterSourceRange();
20
21Node* preRecordPythonTrace(
22 THPObjectPtr pyobj,
23 const std::string& arg_types,
24 at::ArrayRef<autograd::Variable> inputs,
25 std::vector<THPObjectPtr> scalar_args);
26
27std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
28 const py::function& func,
29 const py::dict& inputs_dict,
30 Stack inputs,
31 const py::function& var_name_lookup_fn,
32 bool strict,
33 bool force_outplace,
34 Module* self = nullptr,
35 const std::vector<std::string>& argument_names = {});
36
37std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
38 const py::function& func,
39 Stack inputs,
40 const py::function& var_name_lookup_fn,
41 bool strict,
42 bool force_outplace,
43 Module* self = nullptr,
44 const std::vector<std::string>& argument_names = {});
45} // namespace tracer
46} // namespace jit
47} // namespace torch
48