1 | #pragma once |
2 | |
3 | #include <ATen/core/Tensor.h> |
4 | #include <torch/csrc/python_headers.h> |
5 | #include <memory> |
6 | |
7 | #include <ATen/core/function_schema.h> |
8 | #include <pybind11/pybind11.h> |
9 | #include <torch/csrc/Exceptions.h> |
10 | #include <torch/csrc/Export.h> |
11 | #include <torch/csrc/autograd/variable.h> |
12 | #include <torch/csrc/utils/pybind.h> |
13 | |
14 | namespace py = pybind11; |
15 | |
16 | // Python object that backs torch.autograd.Variable |
17 | struct THPVariable { |
18 | PyObject_HEAD; |
19 | // Payload |
20 | c10::MaybeOwned<at::Tensor> cdata; |
21 | // Hooks to be run on backwards pass (corresponds to Python attr |
22 | // '_backwards_hooks', set by 'register_hook') |
23 | PyObject* backward_hooks = nullptr; |
24 | }; |
25 | |
26 | TORCH_API void registerPythonTensorClass( |
27 | const std::string& device, |
28 | PyObject* python_tensor_class); |
29 | |
30 | TORCH_API void activateCUDATrace(); |
31 | |
32 | TORCH_PYTHON_API extern PyObject* THPVariableClass; |
33 | TORCH_PYTHON_API extern PyObject* ParameterClass; |
34 | |
35 | bool THPVariable_initModule(PyObject* module); |
36 | TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); |
37 | |
38 | static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { |
39 | // Check that a python object is a `Tensor`, but not a `Tensor` subclass. |
40 | // (A subclass could have different semantics.) The one exception is |
41 | // Parameter, which is used for Python bookkeeping but is equivalent to |
42 | // Tensor as far as C++ is concerned. |
43 | return ( |
44 | tp == (PyTypeObject*)THPVariableClass || |
45 | tp == (PyTypeObject*)ParameterClass); |
46 | } |
47 | |
48 | static inline bool THPVariable_CheckExact(PyObject* obj) { |
49 | return THPVariable_CheckTypeExact(Py_TYPE(obj)); |
50 | } |
51 | |
52 | inline bool THPVariable_Check(PyObject* obj) { |
53 | if (!THPVariableClass) |
54 | return false; |
55 | |
56 | const auto result = PyObject_IsInstance(obj, THPVariableClass); |
57 | if (result == -1) |
58 | throw python_error(); |
59 | return result; |
60 | } |
61 | |
62 | inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { |
63 | return *var->cdata; |
64 | } |
65 | |
66 | inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { |
67 | return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj)); |
68 | } |
69 | |
70 | std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs( |
71 | const c10::OperatorHandle& op, |
72 | const std::vector<c10::IValue>& arguments); |
73 | |
74 | void pushPyOutToStack( |
75 | const c10::OperatorHandle& op, |
76 | torch::jit::Stack* stack, |
77 | py::object out, |
78 | const char* msg); |
79 | |