#include <torch/csrc/autograd/python_legacy_variable.h>

#include <ATen/ATen.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/tensor/python_tensor.h>

using namespace at;

namespace torch {
namespace autograd {

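// tp_new implementation for THPLegacyVariableType, exposed to Python as
// torch._C._LegacyVariableBase. It backs the legacy constructor signature,
// e.g.:
//
// ```python
// var = Variable(torch.randn(2, 3), requires_grad=True)
// ```
//
// All five keyword arguments are optional.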
static PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  HANDLE_TH_ERRORS
  THPObjectPtr _data;
  PyObject* data = nullptr;
  PyObject* grad_fn = nullptr;
  char is_volatile = 0;
  char requires_grad = 0;
  const char* name = nullptr;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* accepted_args[] = {
      "data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwds,
          "|ObbOz",
          const_cast<char**>(accepted_args),
          &data,
          &requires_grad,
          &is_volatile,
          &grad_fn,
          &name))
    return nullptr;

  if (grad_fn == Py_None)
    grad_fn = nullptr;

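  // `volatile` was removed from autograd; accept it for backward
  // compatibility but warn that it has no effect.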
  if (is_volatile) {
    auto r = PyErr_WarnEx(
        PyExc_UserWarning,
        "volatile was removed and now has no effect. Use `with torch.no_grad():` "
        "instead.",
        1);
    if (r != 0)
      throw python_error();
  }

  if (is_volatile && requires_grad) {
    throw ValueError(
        "Variable can't be volatile and requires_grad at the same time!");
  }
  if (grad_fn && !THPFunction_Check(grad_fn)) {
    throw TypeError(
        "_grad_fn has to be a Function object or None, but got %s",
        Py_TYPE(grad_fn)->tp_name);
  }
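  // Materialize the underlying tensor: with no `data` (or data=None), create
  // an empty tensor of the default type; otherwise detach the incoming tensor
  // so the new variable does not share its autograd history.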
  Variable var;
  if (!data || data == Py_None) {
    // For legacy serialization code, create an empty tensor. This is also used
    // by nn.Parameter() with no arguments.
    auto dispatch_key = torch::tensors::get_default_dispatch_key();
    auto scalar_type = torch::tensors::get_default_scalar_type();
    auto options = TensorOptions(scalar_type)
                       .device(dispatchKeyToDeviceType(dispatch_key))
                       .layout(dispatchKeyToLayout(dispatch_key));
    var = at::empty({0}, options);
  } else if (THPVariable_Check(data)) {
    var = THPVariable_Unpack(data).detach();
  } else {
    throw torch::TypeError(
        "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
  }
  // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // var = Variable(torch.randn(2, 3))
  // var.resize_(4, 5)
  // ```
  var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);

  TORCH_CHECK(
      !grad_fn,
      "_grad_fn argument to legacy Variable constructor is no longer supported. "
      "Instead, please invoke your _grad_fn to produce a variable with it as the "
      "_grad_fn.");
  var.set_requires_grad(requires_grad);

  if (name) {
    impl::set_name(var, name);
  }

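  // If the JIT tracer is active and `data` already has a trace value, alias
  // the new variable to that value so the trace treats them as the same.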
  if (jit::tracer::isTracing() && data && data != Py_None &&
      THPVariable_Check(data)) {
    if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) {
      jit::tracer::setValueTrace(var, v);
    }
  }

  return THPVariable_Wrap(std::move(var));
  END_HANDLE_TH_ERRORS
}

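// Type object for torch._C._LegacyVariableBase. All slots except tp_new are
// defaulted; Py_TPFLAGS_BASETYPE makes it subclassable, so the Python side
// can derive the legacy Variable class from it.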
PyTypeObject THPLegacyVariableType = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._LegacyVariableBase", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPVariable_pynew /* tp_new */
};

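// Readies the type and exposes it on the given module as
// `_LegacyVariableBase`.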
void init_legacy_variable(PyObject* module) {
  if (PyType_Ready(&THPLegacyVariableType) < 0) {
    throw python_error();
  }
  auto obj = (PyObject*)&THPLegacyVariableType;
  Py_INCREF(obj);
  if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
    throw python_error();
  }
}

} // namespace autograd
} // namespace torch