1#pragma once
2
3#include <ATen/core/Tensor.h>
4#include <torch/csrc/python_headers.h>
5#include <memory>
6
7#include <ATen/core/function_schema.h>
8#include <pybind11/pybind11.h>
9#include <torch/csrc/Exceptions.h>
10#include <torch/csrc/Export.h>
11#include <torch/csrc/autograd/variable.h>
12#include <torch/csrc/utils/pybind.h>
13
14namespace py = pybind11;
15
// Python object that backs torch.autograd.Variable.
//
// This is the C-level layout of the Python object; it is reinterpreted to and
// from PyObject* (see THPVariable_Unpack(PyObject*)), so member order matters.
struct THPVariable {
  // Standard CPython object header. CPython requires it to be the first
  // member so that a THPVariable* may be treated as a PyObject*.
  PyObject_HEAD;
  // Payload: the C++ tensor this Python object refers to. c10::MaybeOwned
  // holds either an owned tensor or a borrowed reference to one.
  c10::MaybeOwned<at::Tensor> cdata;
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook'). Defaults to nullptr.
  PyObject* backward_hooks = nullptr;
};
25
// Associates a Python tensor class object with a device name.
// NOTE(review): presumably used to map a device string to the Python class
// that should represent tensors on it — confirm against the definition.
// NOTE(review): this function and activateCUDATrace are exported with
// TORCH_API, while the other Python-binding symbols in this header use
// TORCH_PYTHON_API; confirm which library actually defines them.
TORCH_API void registerPythonTensorClass(
    const std::string& device,
    PyObject* python_tensor_class);

// NOTE(review): semantics not visible from this header; presumably turns on
// CUDA tracing hooks — confirm against the definition.
TORCH_API void activateCUDATrace();

// Python class object for Tensor. May be null (THPVariable_Check returns
// false in that case); presumably populated during module initialization.
TORCH_PYTHON_API extern PyObject* THPVariableClass;
// Python class object for Parameter; treated as an exact Tensor type by
// THPVariable_CheckTypeExact.
TORCH_PYTHON_API extern PyObject* ParameterClass;

// Module initialization hook for the Variable type.
// NOTE(review): presumably returns false on failure with a Python error set —
// confirm against the definition.
bool THPVariable_initModule(PyObject* module);
// Wraps a C++ tensor in a Python object.
// NOTE(review): presumably returns a new reference (or nullptr on error) —
// confirm ownership semantics against the definition.
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var);
37
38static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
39 // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
40 // (A subclass could have different semantics.) The one exception is
41 // Parameter, which is used for Python bookkeeping but is equivalent to
42 // Tensor as far as C++ is concerned.
43 return (
44 tp == (PyTypeObject*)THPVariableClass ||
45 tp == (PyTypeObject*)ParameterClass);
46}
47
48static inline bool THPVariable_CheckExact(PyObject* obj) {
49 return THPVariable_CheckTypeExact(Py_TYPE(obj));
50}
51
52inline bool THPVariable_Check(PyObject* obj) {
53 if (!THPVariableClass)
54 return false;
55
56 const auto result = PyObject_IsInstance(obj, THPVariableClass);
57 if (result == -1)
58 throw python_error();
59 return result;
60}
61
62inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
63 return *var->cdata;
64}
65
66inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
67 return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
68}
69
// Converts the IValue `arguments` of a call to operator `op` into a pair of
// (Python positional-args object, Python kwargs dict).
// NOTE(review): presumably the split between positional and keyword
// arguments follows `op`'s schema — confirm against the definition.
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments);

// Pushes the Python result `out` of calling operator `op` onto the JIT
// `stack`.
// NOTE(review): `msg` is presumably used in error messages when `out` does
// not match what `op` returns — confirm against the definition.
void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg);
79