1// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#pragma once
// note: PyTorch's python_variable.h simply includes pybind, which conflicts with minpybind,
// so this file just reproduces the minimal API needed to extract Tensors from Python objects.
10
11#include <torch/csrc/python_headers.h>
12#include <ATen/core/Tensor.h>
13#include <torch/csrc/Export.h>
14
// Python object that backs torch.autograd.Variable.
// This is a local mirror of PyTorch's own THPVariable definition, reproduced
// here so that pybind is not pulled in (see the note at the top of this file).
// THPVariable_Unpack reinterpret_casts Python objects to this type, so the field
// order and types must match PyTorch's definition exactly.
// NOTE(review): confirm this layout against torch/csrc/autograd/python_variable.h
// for the PyTorch version being built against.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
  // Standard CPython object header; it must be the first member.
  PyObject_HEAD;
  // Payload: the C++ tensor, held through c10::MaybeOwned (which can hold
  // either an owned or a borrowed value).
  c10::MaybeOwned<at::Tensor> cdata;
  // Hooks to be run on the backward pass (corresponds to the Python attribute
  // '_backward_hooks', set by 'register_hook'). Defaults to nullptr.
  PyObject* backward_hooks = nullptr;
};
25
// Python class object that Tensors are checked against. It may be null;
// THPVariable_Check returns false in that case.
TORCH_PYTHON_API extern PyObject *THPVariableClass;
// Python class object for Parameter (presumably torch.nn.Parameter -- confirm).
TORCH_PYTHON_API extern PyObject *ParameterClass;

// Wraps an at::TensorBase into a Python object. It is declared here and defined
// in PyTorch's Python bindings (TORCH_PYTHON_API).
// NOTE(review): reference ownership of the result is not visible from this
// header -- presumably a new reference; verify against the definition.
TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var);
30
31inline bool THPVariable_Check(PyObject *obj)
32{
33 if (!THPVariableClass)
34 return false;
35
36 const auto result = PyObject_IsInstance(obj, THPVariableClass);
37 AT_ASSERT(result != -1);
38 return result;
39}
40
41inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
42 return *var->cdata;
43}
44
45inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
46 return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
47}
48
// Returns a c10::impl::PyInterpreter pointer. This is only a declaration; the
// definition lives in PyTorch's Python bindings (TORCH_PYTHON_API).
// NOTE(review): presumably the interpreter registered by torch's Python
// bindings; confirm the lifetime/ownership of the returned pointer there.
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
50