1 | // Copyright (c) Facebook, Inc. and its affiliates. |
---|---|
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #pragma once |
// Note: PyTorch's Python variable header simply includes pybind, which conflicts
// with minpybind, so this file reproduces just the minimal API needed to extract
// Tensors from Python objects.
10 | |
11 | #include <torch/csrc/python_headers.h> |
12 | #include <ATen/core/Tensor.h> |
13 | #include <torch/csrc/Export.h> |
14 | |
15 | // Python object that backs torch.autograd.Variable |
16 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
17 | struct THPVariable { |
18 | PyObject_HEAD; |
19 | // Payload |
20 | c10::MaybeOwned<at::Tensor> cdata; |
21 | // Hooks to be run on backwards pass (corresponds to Python attr |
22 | // '_backwards_hooks', set by 'register_hook') |
23 | PyObject* backward_hooks = nullptr; |
24 | }; |
25 | |
26 | TORCH_PYTHON_API extern PyObject *THPVariableClass; |
27 | TORCH_PYTHON_API extern PyObject *ParameterClass; |
28 | |
29 | TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var); |
30 | |
31 | inline bool THPVariable_Check(PyObject *obj) |
32 | { |
33 | if (!THPVariableClass) |
34 | return false; |
35 | |
36 | const auto result = PyObject_IsInstance(obj, THPVariableClass); |
37 | AT_ASSERT(result != -1); |
38 | return result; |
39 | } |
40 | |
41 | inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { |
42 | return *var->cdata; |
43 | } |
44 | |
45 | inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { |
46 | return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj)); |
47 | } |
48 | |
49 | TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); |
50 |