1#include <torch/csrc/autograd/python_legacy_variable.h>
2
3#include <ATen/ATen.h>
4
5#include <torch/csrc/Exceptions.h>
6#include <torch/csrc/autograd/python_function.h>
7#include <torch/csrc/autograd/python_variable.h>
8#include <torch/csrc/jit/frontend/tracer.h>
9#include <torch/csrc/tensor/python_tensor.h>
10
11using namespace at;
12
13namespace torch {
14namespace autograd {
15
16static PyObject* THPVariable_pynew(
17 PyTypeObject* type,
18 PyObject* args,
19 PyObject* kwds) {
20 HANDLE_TH_ERRORS
21 THPObjectPtr _data;
22 PyObject* data = nullptr;
23 PyObject* grad_fn = nullptr;
24 char is_volatile = 0;
25 char requires_grad = 0;
26 const char* name = nullptr;
27
28 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
29 constexpr char* accepted_args[] = {
30 "data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
31 if (!PyArg_ParseTupleAndKeywords(
32 args,
33 kwds,
34 "|ObbOz",
35 const_cast<char**>(accepted_args),
36 &data,
37 &requires_grad,
38 &is_volatile,
39 &grad_fn,
40 &name))
41 return nullptr;
42
43 if (grad_fn == Py_None)
44 grad_fn = nullptr;
45
46 if (is_volatile) {
47 auto r = PyErr_WarnEx(
48 PyExc_UserWarning,
49 "volatile was removed and now has no effect. Use `with torch.no_grad():` "
50 "instead.",
51 1);
52 if (r != 0)
53 throw python_error();
54 }
55
56 if (is_volatile && requires_grad) {
57 throw ValueError(
58 "Variable can't be volatile and require_grad at the same time!");
59 }
60 if (grad_fn && !THPFunction_Check(grad_fn)) {
61 throw TypeError(
62 "_grad_fn has to be a Function object or None, but got %s",
63 Py_TYPE(grad_fn)->tp_name);
64 }
65 Variable var;
66 if (!data || data == Py_None) {
67 // For legacy serialization code, create an empty tensor. This is also used
68 // by nn.Parameter() with no arguments.
69 auto dispatch_key = torch::tensors::get_default_dispatch_key();
70 auto scalar_type = torch::tensors::get_default_scalar_type();
71 auto options = TensorOptions(scalar_type)
72 .device(dispatchKeyToDeviceType(dispatch_key))
73 .layout(dispatchKeyToLayout(dispatch_key));
74 var = at::empty({0}, options);
75 } else if (THPVariable_Check(data)) {
76 var = THPVariable_Unpack(data).detach();
77 } else {
78 throw torch::TypeError(
79 "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
80 }
81 // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we
82 // want to allow the following use case for backward compatibility:
83 //
84 // ```python
85 // var = Variable(torch.randn(2, 3))
86 // var.resize_(4, 5)
87 // ```
88 var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
89
90 TORCH_CHECK(
91 !grad_fn,
92 "_grad_fn argument to legacy Variable constructor is no longer supported. "
93 "Instead, please invoke your _grad_fn to produce a variable with it as the "
94 "_grad_fn.");
95 var.set_requires_grad(requires_grad);
96
97 if (name) {
98 impl::set_name(var, name);
99 }
100
101 if (jit::tracer::isTracing() && data && data != Py_None &&
102 THPVariable_Check(data)) {
103 if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) {
104 jit::tracer::setValueTrace(var, v);
105 }
106 }
107
108 return THPVariable_Wrap(std::move(var));
109 END_HANDLE_TH_ERRORS
110}
111
// Python type object for `torch._C._LegacyVariableBase`.
//
// The only custom slot is `tp_new` (THPVariable_pynew), which implements the
// legacy Variable constructor. Every other slot is left null/zero here so that
// PyType_Ready can fill it in from the base type (tp_base is null, i.e. the
// default base). Py_TPFLAGS_BASETYPE allows Python code to subclass this type.
//
// NOTE: slots are initialized positionally, so the entry order below must
// match the PyTypeObject layout of the CPython headers being compiled against.
PyTypeObject THPLegacyVariableType = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._LegacyVariableBase", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPVariable_pynew /* tp_new */
};
153
154void init_legacy_variable(PyObject* module) {
155 if (PyType_Ready(&THPLegacyVariableType) < 0) {
156 throw python_error();
157 }
158 auto obj = (PyObject*)&THPLegacyVariableType;
159 Py_INCREF(obj);
160 if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
161 throw python_error();
162 }
163}
164
165} // namespace autograd
166} // namespace torch
167