1 | #include <ATen/NamedTensorUtils.h> |
2 | #include <c10/core/DeviceType.h> |
3 | #include <c10/core/impl/GPUTrace.h> |
4 | #include <c10/core/impl/HermeticPyObjectTLS.h> |
5 | #include <c10/core/impl/PythonDispatcherTLS.h> |
6 | #include <c10/util/irange.h> |
7 | #include <pybind11/pytypes.h> |
8 | #include <torch/csrc/Device.h> |
9 | #include <torch/csrc/DynamicTypes.h> |
10 | #include <torch/csrc/Exceptions.h> |
11 | #include <torch/csrc/PyInterpreter.h> |
12 | #include <torch/csrc/Size.h> |
13 | #include <torch/csrc/THP.h> |
14 | #include <torch/csrc/Types.h> |
15 | #include <torch/csrc/autograd/autograd.h> |
16 | #include <torch/csrc/autograd/edge.h> |
17 | #include <torch/csrc/autograd/function.h> |
18 | #include <torch/csrc/autograd/python_cpp_function.h> |
19 | #include <torch/csrc/autograd/python_hook.h> |
20 | #include <torch/csrc/autograd/python_variable_indexing.h> |
21 | #include <torch/csrc/autograd/utils/error_messages.h> |
22 | #include <torch/csrc/autograd/utils/wrap_outputs.h> |
23 | #include <torch/csrc/autograd/variable.h> |
24 | #include <torch/csrc/jit/frontend/tracer.h> |
25 | #include <torch/csrc/jit/python/pybind_utils.h> |
26 | #include <torch/csrc/tensor/python_tensor.h> |
27 | #include <torch/csrc/utils/pybind.h> |
28 | #include <torch/csrc/utils/pycfunction_helpers.h> |
29 | #include <torch/csrc/utils/python_arg_parser.h> |
30 | #include <torch/csrc/utils/python_dispatch.h> |
31 | #include <torch/csrc/utils/python_strings.h> |
32 | #include <torch/csrc/utils/tensor_new.h> |
33 | #include <torch/csrc/utils/tensor_numpy.h> |
34 | |
35 | #include <torch/csrc/utils/torch_dispatch_mode.h> |
36 | |
37 | #include <ATen/ATen.h> |
38 | |
39 | #include <c10/core/SymIntArrayRef.h> |
40 | #include <structmember.h> |
41 | #include <cstdint> |
42 | #include <iostream> |
43 | #include <memory> |
44 | #include <utility> |
45 | #include <vector> |
46 | |
47 | using namespace at; |
48 | using namespace torch; |
49 | using namespace torch::autograd; |
50 | |
// Converts the flat IValue argument list of operator `op` into a Python
// (args, kwargs) pair, as it would be spelled at a Python call site:
// trailing positional arguments still holding their schema default are
// omitted, and kwarg-only arguments go into `kwargs` (defaulted ones
// skipped). Requires the GIL to be held by the caller.
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments) {
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");
  const auto& schema = op.schema();
  py::dict kwargs;
  // About all the pointers:
  //
  // f(int x, int y = 0, *, int z = 0)
  //                                  ^- arguments.size()
  //                        ^- kwarg_only_start
  //          ^- positional_default_start
  //   ^- 0

  // Find the split point between kwarg-only and regular.  Since most functions
  // don't have kwarg-only arguments, it is more efficient to scan from the
  // right (but ideally, this would just be precomputed in FunctionSchema
  // itself).  (NB: minus one in the loop is because we're testing if the
  // *next* argument is kwarg-only before we advance the starting index)
  int64_t kwarg_only_start = arguments.size();
  for (; kwarg_only_start > 0; kwarg_only_start--) {
    const auto& arg = schema.arguments()[kwarg_only_start - 1];
    if (!arg.kwarg_only()) {
      break;
    }
  }

  // Find the first positional argument that isn't defaulted
  // (is_default: true iff arguments[idx] compares equal to its schema
  // default value; arguments without a default are never "default").
  auto is_default = [&](int64_t idx) -> bool {
    const auto& arg = schema.arguments()[idx];
    if (!arg.default_value().has_value()) {
      return false;
    }
    const auto& default_ivalue = *arg.default_value();
    const auto& ivalue = arguments[idx];
    if (default_ivalue != ivalue) {
      return false;
    }
    return true;
  };

  int64_t positional_default_start = kwarg_only_start;
  for (; positional_default_start > 0; positional_default_start--) {
    if (!is_default(positional_default_start - 1)) {
      break;
    }
  }

  auto args =
      py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));

  // Convert arguments[idx] to a PyObject.  The schema's declared (real) type
  // is consulted so that integer-coded arguments (ScalarType, Layout,
  // MemoryFormat) surface as their proper Python objects rather than ints.
  auto schemaAwareToPyObject = [&](int64_t idx) -> py::object {
    const auto& arg = schema.arguments()[idx];
    // Accept either `kind` itself or Optional[kind].
    auto match = [&](c10::TypeKind kind) {
      const auto& t = arg.real_type();
      if (t->kind() == kind)
        return true;
      if (auto opt_t = t->cast<c10::OptionalType>()) {
        if (opt_t->getElementType()->kind() == kind)
          return true;
      }
      return false;
    };
    if (arguments[idx].isNone()) {
      return py::none();
    } else if (match(c10::ScalarTypeType::Kind)) {
      auto* obj =
          getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::LayoutType::Kind)) {
      auto* obj =
          getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::MemoryFormatType::Kind)) {
      return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
    } else {
      return torch::jit::toPyObject(arguments[idx]);
    }
  };

  // Populate positional arguments
  for (const auto idx : c10::irange(positional_default_start)) {
    PyTuple_SET_ITEM(
        args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr());
  }

  // Populate keyword arguments
  for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
    // But don't populate default keyword arguments
    if (is_default(idx))
      continue;
    const auto& arg = schema.arguments()[idx];
    kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx);
  }
  return std::make_pair(std::move(args), std::move(kwargs));
}
151 | |
152 | void pushPyOutToStack( |
153 | const c10::OperatorHandle& op, |
154 | torch::jit::Stack* stack, |
155 | py::object out, |
156 | const char* msg) { |
157 | TORCH_CHECK( |
158 | PyGILState_Check(), "GIL must be held before you call pushPyOutToStack" ); |
159 | auto schema_returns = op.schema().returns(); |
160 | const auto num_returns = schema_returns.size(); |
161 | if (num_returns == 0) { |
162 | // Check that we got a None return from Python. Anything else is an error. |
163 | TORCH_CHECK( |
164 | out.is_none(), |
165 | "Expected " , |
166 | msg, |
167 | " for " , |
168 | op.operator_name(), |
169 | " to return None but it returned something else instead." ); |
170 | } else if (num_returns == 1) { |
171 | torch::jit::push( |
172 | stack, torch::jit::toIValue(out.ptr(), schema_returns[0].type())); |
173 | } else { |
174 | auto outs = py::cast<py::sequence>(out); |
175 | for (const auto idx : c10::irange(outs.size())) { |
176 | torch::jit::push( |
177 | stack, |
178 | torch::jit::toIValue(outs[idx].ptr(), schema_returns[idx].type())); |
179 | } |
180 | } |
181 | } |
182 | |
183 | namespace { |
184 | |
185 | c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument( |
186 | c10::string_view arg) { |
187 | if (arg == "strides" ) { |
188 | return c10::TensorImpl::SizesStridesPolicy::CustomStrides; |
189 | } |
190 | |
191 | if (arg == "sizes" ) { |
192 | return c10::TensorImpl::SizesStridesPolicy::CustomSizes; |
193 | } |
194 | |
195 | TORCH_CHECK_VALUE( |
196 | false, |
197 | "Unknown sizes_strides_policy: " , |
198 | arg, |
199 | "; expected 'strides' or 'sizes'" ); |
200 | } |
201 | } // anonymous namespace |
202 | |
// Python class object for torch.Tensor (assigned during module init).
PyObject* THPVariableClass = nullptr;

// Python class object for torch.nn.Parameter (assigned during module init).
PyObject* ParameterClass = nullptr;

// Creates a new PyObject of `type` wrapping `_var`; defined later in this
// file.  `status` describes what is known about _var's pyobj slot.
static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
    Variable _var,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj = false);

// clang-tidy gets confused by static const
// Warning text emitted when the removed `volatile` flag is accessed.
static const char* VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";
217 | |
218 | static bool check_has_torch_dispatch(PyObject* obj) { |
219 | PyTypeObject* tp = Py_TYPE(obj); |
220 | if (THPVariable_CheckTypeExact(tp)) { |
221 | return false; |
222 | } |
223 | py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__" ); |
224 | return ( |
225 | attr.ptr() != nullptr && |
226 | attr.ptr() != torch::disabled_torch_dispatch_impl()); |
227 | } |
228 | |
// Per-device-type override of the Python class used to wrap tensors on that
// device; only populated via registerPythonTensorClass (XLA only today).
// NOLINTNEXTLINE
static PyObject* device_to_py_class_[static_cast<size_t>(
    c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
232 | |
233 | void registerPythonTensorClass( |
234 | const std::string& device, |
235 | PyObject* python_tensor_class) { |
236 | c10::Device dev(device); |
237 | |
238 | TORCH_CHECK( |
239 | dev.type() == kXLA, "Only the python class for XLA can be overriden" ); |
240 | if (device_to_py_class_[static_cast<size_t>(dev.type())] != nullptr) { |
241 | TORCH_WARN( |
242 | "Overriding a previously registered python class for " , dev.str()); |
243 | } |
244 | |
245 | device_to_py_class_[static_cast<size_t>(dev.type())] = python_tensor_class; |
246 | } |
247 | |
248 | static PyObject* getPythonTensorClass(c10::Device d) { |
249 | return device_to_py_class_[static_cast<size_t>(d.type())]; |
250 | } |
251 | |
// Installs the current Python interpreter as the GPU trace observer
// (c10::impl::GPUTrace), enabling Python-side GPU tracing hooks.
void activateCUDATrace() {
  c10::impl::GPUTrace::set_trace(getPyInterpreter());
}
255 | |
// TODO: Make this take Variable by const reference
// Returns a NEW reference to the Python object wrapping `var`, creating one
// if it does not already exist.  Handles hermetic mode (always a fresh
// object), flipping a C++-owned PyObject back to Python ownership, and
// device-specific (XLA) wrapper classes.  Undefined tensors map to None.
PyObject* THPVariable_Wrap(at::TensorBase var) {
  if (!var.defined()) {
    Py_RETURN_NONE;
  }

  if (c10::impl::HermeticPyObjectTLS::get_state()) {
    // Hermetic mode: never reuse the PyObject cached on the TensorImpl;
    // always build a fresh wrapper.
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass,
        std::move(var),
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  }

  c10::optional<PyObject*> mb_obj =
      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
  c10::impl::PyInterpreterStatus status;
  if (mb_obj.has_value()) {
    auto obj = *mb_obj;
    if (obj) {
      if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
        // C++ owns the Python object; this implies there weren't any other
        // owning references to the Python object.  Since we're making the
        // object "live" again on Python side, let's flip back the ownership
        // (Python owns C++) as it would now be unsound to deallocate the C++
        // object if all C++ references go to zero
        var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
        reinterpret_cast<THPVariable*>(obj)->cdata =
            MaybeOwned<Variable>::owned(std::move(var));
        // NB: incref is not necessary, because we are "stealing" the previous
        // ownership from the Variable to return it here for the wrap
        return obj;
      }
      Py_INCREF(obj);
      return obj;
    }
    // TODO: a better invariant is that if we tagged, we MUST have a valid
    // PyObject.  That's PyObject preservation
    // (https://github.com/pytorch/pytorch/pull/56017).  Prior to this PR
    // being a thing, the PyObject field will get cleared when all references
    // to the Python object are removed.
    status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
  } else {
    // Assumption: if a Tensor has been shared across threads, this induces
    // a refcount bump.  Therefore, if the use count is 1, we are the sole
    // thread with access to this tensor and no race is possible.
    if (var.use_count() <= 1) {
      status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
    } else {
      status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
    }
  }

  // Common case: non-XLA tensors are wrapped with the standard torch.Tensor
  // class.
  if (C10_LIKELY(var.device().type() != c10::kXLA)) {
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass, std::move(var), status);
  }

  // XLA tensors may use a custom registered Python class.
  if (auto clazz = getPythonTensorClass(var.device())) {
    return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)THPVariableClass, std::move(var), status);
}
320 | |
// Decides whether a THPVariable that is about to be deallocated should
// instead be kept alive ("resurrected"): true when the PyObject owns the C++
// tensor, the tensor's pyobj slot points back at `self` (non-hermetic), and
// other C++ references to the tensor still exist.
bool isResurrectable(THPVariable* self) {
  // We want to divide this check into 2 cases.

  // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
  // true). You might think that in this case, it is impossible for tp_clear to
  // be called: surely the C++ reference to the PyObject is keeping it live? And
  // you'd be right! In fact, when C++ owns the PyObject, we have an invariant
  // that the refcount on the PyObject should be precisely one (because if you
  // take out another reference to the PyObject, we're supposed to flip the
  // ownership pointer back). In reality, you can violate this invariant
  // temporarily with weak references, so we don't test for it in asserts.

  // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
  // false). In this case, tp_clear can get called if the PyObject is referenced
  // from a dead cycle, and nowhere else. But if resurrection did not occur,
  // then the reference to C++ from the PyObject must be the ONLY reference to
  // the C++ object.
  if (self->cdata.unsafeIsBorrowed()) {
    return false;
  }
  auto const& tensor = THPVariable_Unpack(self);
  // Check if this is hermetic. If it is, no resurrection.
  if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter()) != c10::make_optional((PyObject*)self)) {
    return false;
  }
  // Undefined tensor, or the PyObject holds the only remaining reference:
  // nothing on the C++ side requires this object to stay alive.
  if (!tensor.defined() || tensor.use_count() <= 1) {
    return false;
  }
  return true;
}
352 | |
// returns true if successfully rezzed; if so, cancel the
// rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
  const auto& tensor = THPVariable_Unpack(self);

  if (!isResurrectable(self)) {
    return false;
  }

  // At this point, we are definitely going to resurrect the tensor. So, the
  // tensor better be defined :)
  TORCH_INTERNAL_ASSERT(tensor.defined());

  // There are other C++ owners of the tensor. Flip ownership
  // so that C++ owns this Python object, and cancel deallocation.
  TORCH_INTERNAL_ASSERT(
      !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

  tensor.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(true);

  // Resurrect the Python object. This is something CPython does
  // internally occasionally, see
  // https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
  // so we just copy the pattern here. Note that we don't have to worry
  // about saving and restoring the refcount (as the quoted code does)
  // because we actually DO need to reset the refcount to one here, we
  // can't assume that some other code has taken care of it.
  // NB: this will overreport _Py_RefTotal but based on inspection of object.c
  // there is no way to avoid this
#ifdef Py_TRACE_REFS
  _Py_AddToAllObjects(reinterpret_cast<PyObject*>(self), 1);
#endif
  // Refcount was 0 on entry to dealloc; bring the object back to refcount 1.
  Py_INCREF(self);

  // Flip THPVariable to be non-owning
  // (near use-after-free miss here: fresh MaybeOwned is created breaking
  // reference on Tensor in struct BEFORE we overwrite the old one)
  TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
  self->cdata = MaybeOwned<Variable>::borrowed(tensor);

  // NB: At this point, tensor *could* be dead (e.g., some other C++ thread
  // decrefed it.)  At this point, it is probably waiting on the GIL to
  // deallocate the Python object and will kill self, BUT NOT YET.

  return true;
}
399 | |
// tp_clear for torch.Tensor: breaks reference cycles by dropping the
// backward hooks and the owned C++ tensor — unless the object qualifies for
// resurrection (other C++ owners exist), in which case nothing is cleared.
static int THPVariable_clear(THPVariable* self) {
  // Is it OK for an object to still be live after running
  // tp_clear? Yes. When Python is breaking reference cycles, it can't assume
  // that an object will dealloc after it's cleared.  The source code explicitly
  // handles this case:
  // https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025

  // Note that we don't need to actually resurrect here. There are 2 cases:
  // 1. The PyObject is not part of a reference cycle. In this case, we don't
  // need to do anything. The GC will move on to try and break the reference
  // cycle on another object, which will eventually trigger tp_dealloc (and thus
  // resurrection).

  // 2. The PyObject is part of a reference cycle. This case should not actually
  // be possible, due to the logic in our tp_traverse
  // (THPVariable_subclass_traverse).

  // In fact, resurrecting here breaks the invariant that "C++ owns Python only
  // when PyObject's refcount would otherwise be 0". Most immediately, as we're
  // merely breaking reference cycles here, there can be other references to the
  // PyObject. *However*, if other objects in the refcycle resurrect, then we
  // will be in a state where the PyObject has multiple Python references, yet
  // C++ owns the PyObject.

  // See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
  if (isResurrectable((THPVariable*)self)) {
    return 0;
  }
  Py_CLEAR(self->backward_hooks);
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.defined()) {
    // Two situations to consider:
    //    PyObject -owns-> Tensor
    //        unsafeIsBorrowed() is FALSE.  We're obligated to look through
    //        Tensor to break references.  Clearing cdata must induce the
    //        destruction of the C++ Tensor.  If there were other references
    //        to C++ tensor, the Python object would have been resurrected
    //        by flipping the ownership.
    //    Tensor -owns-> PyObject
    //        unsafeIsBorrowed() is TRUE.  We're deallocating the PyObject
    //        because Tensor asked us to (it's already destructing).

    if (!self->cdata.unsafeIsBorrowed() &&
        tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
            getPyInterpreter()) == c10::make_optional((PyObject*)self)) {
      // TODO: empirically, on OS X this assert appears to be untrue
      // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
      // distributed/rpc/test_process_group_agent.py
      //
      //  libc++abi.dylib: terminating with uncaught exception of type
      //  c10::Error:
      //  !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
      //  ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
      //  please report a bug to PyTorch. Exception raised from
      //  THPVariable_clear at
      //  ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
      //  first): frame #0: c10::Error::Error(c10::SourceLocation,
      //  std::__1::basic_string<char, std::__1::char_traits<char>,
      //  std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
      //  #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
      //  int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
      //  c10::detail::torchInternalAssertFail(char const*, char const*,
      //  unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
      //  (0x1141e3f89 in libtorch_python.dylib) frame #3:
      //  THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
      //  libtorch_python.dylib) frame #4:
      //  THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
      //  libtorch_python.dylib) frame #5: (anonymous
      //  namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
      //  _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
      //  c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
      //  libc10.dylib) frame #7:
      //  c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
      //  + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
      //  THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
      //  libtorch_python.dylib) <omitting python frames> frame #47: start + 1
      //  (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
      // TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
      // Drop Python hooks registered on the gradient accumulator so the
      // cycle through them is broken as well.
      if (auto grad_acc =
              torch::autograd::impl::try_get_grad_accumulator(tensor)) {
        grad_acc->pre_hooks().clear();
        grad_acc->tensor_pre_hooks().clear();
        grad_acc->retains_grad_hooks().clear();
      }
    }
  }
  TORCH_INTERNAL_ASSERT(!isResurrectable((THPVariable*)self));
  {
    // MapAllocator can take significant time to release large tensors;
    // release the GIL here to avoid impacting main thread perf.
    pybind11::gil_scoped_release no_gil;
    self->cdata = MaybeOwned<Variable>();
  }
  return 0;
}
495 | |
// tp_traverse stub: a real traverse function is expected to be installed
// elsewhere; reaching this implementation always indicates a setup bug.
int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  TORCH_INTERNAL_ASSERT(
      false, "Tensor tp_traverse function was not overriden properly");
  return 0;
}
501 | |
// tp_new for torch.Tensor; defined elsewhere in this file.
PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs);
506 | |
507 | static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { |
508 | const auto& var = THPVariable_Unpack(self); |
509 | Py_DECREF(THPVariable_Wrap(var)); |
510 | Py_RETURN_NONE; |
511 | } |
512 | |
// Binding for Tensor._view_func(new_base): replays the view operation that
// produced `self` on top of `new_base`.  Returns None (via wrapping an
// undefined tensor) when self is not a backward-differentiable view or when
// new_base's metadata does not match the original base.
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(self_);
  TORCH_CHECK(
      THPVariable_Check(arg),
      "_view_func expect a single argument that is a Tensor");
  const auto& new_base = THPVariable_Unpack(arg);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Ensure that the newly provided base is similar to the original base
    if (torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
      // Do the actual view replay
      if (view_info.has_view_fn()) {
        // A stored view function exists (e.g. for non-as_strided-able views).
        out = view_info.view_fn()(new_base);
      } else {
        // Otherwise replay with as_strided using self's own geometry.
        out = new_base.as_strided(
            self.sizes(), self.strides(), self.storage_offset());
      }
    }
  }
  return THPVariable_Wrap(std::move(out));
  END_HANDLE_TH_ERRORS
}
541 | |
// Instantiates a subclass of self with the same data.
// Binding for Tensor.as_subclass(cls): wraps an alias of self (no data copy)
// in a brand-new PyObject of type `cls`, which must be a type object.
static PyObject* THPVariable_as_subclass(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);
  static PythonArgParser parser({
      "as_subclass(PyObject* cls)",
  });
  ParsedArgs<1> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  if (!PyType_Check(cls)) {
    throw torch::TypeError(
        "cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
  }
  // The PyObject is freshly created, so its pyobj slot is definitely
  // uninitialized.
  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      self.alias(),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}
565 | |
// Binding for torch.Tensor._make_subclass: wraps a detached alias of `data`
// in Python class `cls`, optionally enabling custom Python-side
// sizes/strides, device, or layout dispatch on the TensorImpl, and
// optionally swapping the TensorImpl's backend component keys.
static PyObject* THPVariable_make_subclass(
    PyObject* _ignored,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, Device? device_for_backend_keys=None)",
  });
  ParsedArgs<7> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  if (!PyType_Check(cls)) {
    throw torch::TypeError(
        "cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
  }
  // guard completely turns off torch dispatch modes, doesn't just pop off the
  // stack
  torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
  c10::impl::DisablePythonDispatcher dpd_g;
  auto data =
      r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
  // We set `data`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // rnn = torch.nn.RNN(100, 100, 2)
  // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights,
  // ...)`, # which changes storage of `rnn`'s weights in-place
  // rnn.flatten_parameters()
  // ```
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  data.set_requires_grad(r.toBool(2));
  // Optional dispatch_sizes_strides_policy ("sizes" or "strides").
  const auto sizes_strides_policy = r.stringViewOptional(3);
  if (sizes_strides_policy.has_value()) {
    data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
        parseSizesStridesPolicyArgument(*sizes_strides_policy));
  }
  // dispatch_device / dispatch_layout flags.
  if (r.toBool(4)) {
    data.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(5)) {
    data.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }
  // device_for_backend_keys: rewrite the backend component of dispatch keys.
  if (!r.isNone(6)) {
    data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      std::move(data),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}
619 | |
// Binding for torch.Tensor._make_wrapper_subclass: creates a tensor with no
// real data (null data pointer / zero-byte storage) wrapped in Python class
// `cls`, which must define __torch_dispatch__.  Overload 0 takes concrete
// int sizes/strides; overload 1 takes SymInts.
static PyObject* THPVariable_make_wrapper_subclass(
    PyObject*,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // NB: pin_memory doesn't actually do anything
  // TODO: strides variant?
  static PythonArgParser parser({
      "_make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, "
      "int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
      "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False)",
      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
      "SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
      "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False)",
  });
  ParsedArgs<13> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);

  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");

  // This is an important safety check; without it, the default behavior will be
  // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
  // key, which will immediately segfault because the data pointer is null.  By
  // forcing users to define __torch_dispatch__ we ensure this does not happen
  // TODO: This check is not complete; because the user can disable torch
  // dispatch and then go again, triggering segfault.  TBH I'm thinking I want
  // to delete this function entirely
  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
  TORCH_CHECK_TYPE(
      attr.ptr() != nullptr &&
          attr.ptr() != torch::disabled_torch_dispatch_impl(),
      ((PyTypeObject*)cls)->tp_name,
      " must define __torch_dispatch__");

  const auto options = TensorOptions()
                           .dtype(r.scalartype(5))
                           .device(r.device(7))
                           .layout(r.layoutOptional(6))
                           // NB: long standing issue, requires_grad is not
                           // respected here; you have to set it post facto, see
                           // https://github.com/pytorch/pytorch/issues/26428
                           // .requires_grad(r.toBool(7))
                           .pinned_memory(r.toBool(8));

  // don't bother releasing GIL here, as we are not allocating any nontrivial
  // data
  // TODO: for_blob produces non-resizable tensors, we might want this to be
  // resizable (have to define a custom allocator in that case)
  Tensor tensor;
  if (r.idx == 0) {
    // Overload 0: concrete int sizes/strides -> build via for_blob with a
    // null data pointer and a no-op context deleter.
    tensor = at::for_blob(nullptr, r.intlist(1))
                 .strides(r.intlistOptional(2))
                 .storage_offset(r.toInt64Optional(3))
                 .context(nullptr, [](void* ctx) {})
                 .target_device(
                     options.device()) // TODO: this shouldn't be necessary if
                                       // it came from options
                 .options(options)
                 .make_tensor();

    const auto sizes_strides_policy = r.stringViewOptional(10);
    if (sizes_strides_policy.has_value()) {
      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
          parseSizesStridesPolicyArgument(*sizes_strides_policy));
    }
  } else {
    // Overload 1: symbolic sizes/strides -> construct the TensorImpl
    // directly with an empty storage and set symbolic geometry on it.
    AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
    tracer::impl::NoTracerDispatchMode tracer_guard{};

    // We shouldn't need storage
    Storage storage{Storage::use_byte_size_t{}, 0, at::DataPtr{}};

    tensor = at::detail::make_tensor<TensorImpl>(
        std::move(storage), options.computeDispatchKey(), options.dtype());

    auto sym_sizes = r.symintlist(1);
    auto sym_strides = r.symintlist(2);
    auto sym_storage_offset = r.toSymIntOptional(3);

    TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();

    tensor_impl->set_sizes_and_strides(
        sym_sizes, sym_strides, sym_storage_offset.value_or(0));

    const auto sizes_strides_policy = r.stringViewOptional(10);
    if (sizes_strides_policy.has_value()) {
      TORCH_CHECK(
          false,
          "Setting sizes_strides_policy isn't supported for this overload")
    }
  }

  tensor.set_requires_grad(r.toBool(9));

  // dispatch_device / dispatch_layout flags.
  if (r.toBool(11)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(12)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      std::move(tensor),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}
734 | |
735 | typedef PyObject* (*getter)(PyObject*, void*); |
736 | typedef int (*setter)(PyObject*, PyObject*, void*); |
737 | |
738 | PyObject* THPVariable_get_python_dispatch(THPVariable* self, void* unused) { |
739 | HANDLE_TH_ERRORS |
740 | const auto& var = THPVariable_Unpack(self); |
741 | return torch::autograd::utils::wrap( |
742 | var.unsafeGetTensorImpl()->is_python_dispatch()); |
743 | END_HANDLE_TH_ERRORS |
744 | } |
745 | |
746 | // CRTP base class to implement the python bindings for a Tensor property in |
747 | // PyTorch A class that implements a property is expected to have: |
748 | // - static constexpr const char* name; |
749 | // - This variable should hold the Python name of the property |
750 | // - static Tensor fn(const Tensor&); |
751 | // - This function calls the relevant ATen on the tensor |
752 | template <typename T> |
753 | struct GetterBase { |
754 | static PyObject* getter(THPVariable* self, void* /*unused*/) { |
755 | HANDLE_TH_ERRORS |
756 | if (check_has_torch_function((PyObject*)self)) { |
757 | return handle_torch_function_getter(self, T::name); |
758 | } |
759 | return THPVariable_Wrap(T::fn(THPVariable_Unpack(self))); |
760 | END_HANDLE_TH_ERRORS |
761 | } |
762 | }; |
763 | |
// `Tensor.T` property, backed by Tensor::numpy_T().
struct PropertyT : GetterBase<PropertyT> {
  static constexpr const char* name = "T";
  static Tensor fn(const Tensor& t) {
    return t.numpy_T();
  }
};
770 | |
// `Tensor.H` property, backed by Tensor::matrix_H().
struct PropertyH : GetterBase<PropertyH> {
  static constexpr const char* name = "H";
  static Tensor fn(const Tensor& t) {
    return t.matrix_H();
  }
};
777 | |
// `Tensor.mT` property, backed by Tensor::mT().
struct PropertymT : GetterBase<PropertymT> {
  static constexpr const char* name = "mT";
  static Tensor fn(const Tensor& t) {
    return t.mT();
  }
};
784 | |
// `Tensor.mH` property, backed by Tensor::mH().
struct PropertymH : GetterBase<PropertymH> {
  static constexpr const char* name = "mH";
  static Tensor fn(const Tensor& t) {
    return t.mH();
  }
};
791 | |
// `Tensor.data` getter, backed by Variable::variable_data().
struct PropertyData : GetterBase<PropertyData> {
  static constexpr const char* name = "data";
  static Tensor fn(const Tensor& t) {
    return t.variable_data();
  }
};
798 | |
// `Tensor.grad` getter, backed by Tensor::grad().
struct PropertyGrad : GetterBase<PropertyGrad> {
  static constexpr const char* name = "grad";
  static Tensor fn(const Tensor& t) {
    return t.grad();
  }
};
805 | |
// `Tensor.real` getter, backed by at::real().
struct PropertyReal : GetterBase<PropertyReal> {
  static constexpr const char* name = "real";
  static Tensor fn(const Tensor& t) {
    return at::real(t);
  }
};
812 | |
// `Tensor.imag` getter, backed by at::imag().
struct PropertyImag : GetterBase<PropertyImag> {
  static constexpr const char* name = "imag";
  static Tensor fn(const Tensor& t) {
    return at::imag(t);
  }
};
819 | |
820 | PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) { |
821 | HANDLE_TH_ERRORS |
822 | if (check_has_torch_function((PyObject*)self)) { |
823 | return handle_torch_function_getter(self, "_cdata" ); |
824 | } |
825 | const auto& var = THPVariable_Unpack(self); |
826 | return PyLong_FromVoidPtr(var.unsafeGetTensorImpl()); |
827 | END_HANDLE_TH_ERRORS |
828 | } |
829 | |
830 | PyObject* THPVariable_get_version(THPVariable* self, void* unused) { |
831 | HANDLE_TH_ERRORS |
832 | if (check_has_torch_function((PyObject*)self)) { |
833 | return handle_torch_function_getter(self, "_version" ); |
834 | } |
835 | const auto& var = THPVariable_Unpack(self); |
836 | return PyInt_FromLong(var._version()); |
837 | END_HANDLE_TH_ERRORS |
838 | } |
839 | |
840 | PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) { |
841 | HANDLE_TH_ERRORS |
842 | if (check_has_torch_function((PyObject*)self)) { |
843 | return handle_torch_function_getter(self, "grad_fn" ); |
844 | } |
845 | const auto& var = THPVariable_Unpack(self); |
846 | if (!var.grad_fn()) { |
847 | Py_RETURN_NONE; |
848 | } |
849 | return functionToPyObject(var.grad_fn()); |
850 | END_HANDLE_TH_ERRORS |
851 | } |
852 | |
// Setter for the private `_grad_fn` attribute. The only supported write is
// assigning None, which detaches the tensor in place; deletion or any
// other value is rejected.
static int THPVariable_set_grad_fn(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_grad_fn", obj);
  }
  // obj == nullptr means `del tensor._grad_fn`, which is not allowed.
  THPUtils_assertRet(
      -1, obj, "Deletion of _grad_fn not allowed. Detach tensor instead!");
  THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None");
  THPVariable_Unpack(self).detach_();
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}
868 | |
869 | static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) { |
870 | HANDLE_TH_ERRORS |
871 | if (check_has_torch_function((PyObject*)self)) { |
872 | return handle_torch_function_getter(self, "is_leaf" ); |
873 | } |
874 | return PyBool_FromLong(!THPVariable_Unpack(self).grad_fn()); |
875 | END_HANDLE_TH_ERRORS |
876 | } |
877 | |
// Setter for `tensor.data`: replaces this tensor's data with that of
// another tensor via Variable::set_data. Only tensor values are accepted;
// deletion is rejected.
int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "data", data);
  }
  // data == nullptr means `del tensor.data`.
  THPUtils_assertRet(
      -1, data, "Deleting tensor data is not allowed. Delete tensor instead!");
  if (!THPVariable_Check(data)) {
    throw torch::TypeError(
        "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
  }

  THPVariable_Unpack(self).set_data(THPVariable_Unpack(data));
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}
894 | |
895 | int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) { |
896 | HANDLE_TH_ERRORS |
897 | if (check_has_torch_function((PyObject*)self)) { |
898 | return handle_torch_function_setter(self, "grad" , py_grad); |
899 | } |
900 | const auto& var = THPVariable_Unpack(self); |
901 | if (!py_grad || py_grad == Py_None) { |
902 | var.mutable_grad().reset(); |
903 | return 0; |
904 | } |
905 | |
906 | TORCH_CHECK_TYPE( |
907 | THPVariable_Check(py_grad), |
908 | "assigned grad expected to be a Tensor or None but got grad of type" , |
909 | THPUtils_typename(py_grad)); |
910 | THPUtils_assertRet( |
911 | -1, |
912 | self != (THPVariable*)py_grad, |
913 | "can't assign Variable as its own grad" ); |
914 | |
915 | const auto& grad = THPVariable_Unpack(py_grad); |
916 | bool gradIsSparse = |
917 | (var.dtype() == grad.dtype() && |
918 | var.device().type() == grad.device().type() && grad.layout() == kSparse); |
919 | THPUtils_assertRet( |
920 | -1, |
921 | grad.options().type_equal(var.options()) || gradIsSparse, |
922 | "assigned grad has data of a different type" ); |
923 | if (var.is_cuda()) { |
924 | THPUtils_assertRet( |
925 | -1, |
926 | grad.get_device() == var.get_device(), |
927 | "assigned grad has data located on a different device" ); |
928 | } |
929 | THPUtils_assertRet( |
930 | -1, |
931 | grad.sym_sizes().equals(var.sym_sizes()), |
932 | "assigned grad has data of a different size" ); |
933 | |
934 | var.mutable_grad() = grad; |
935 | return 0; |
936 | END_HANDLE_TH_ERRORS_RET(-1) |
937 | } |
938 | |
939 | PyObject* THPVariable_get_volatile(THPVariable* self, void* unused) { |
940 | HANDLE_TH_ERRORS |
941 | if (check_has_torch_function((PyObject*)self)) { |
942 | return handle_torch_function_getter(self, "volatile" ); |
943 | } |
944 | const char* msg = "volatile was removed (Variable.volatile is always False)" ; |
945 | auto r = PyErr_WarnEx(PyExc_UserWarning, msg, 1); |
946 | if (r != 0) |
947 | throw python_error(); |
948 | Py_RETURN_FALSE; |
949 | END_HANDLE_TH_ERRORS |
950 | } |
951 | |
952 | int THPVariable_set_volatile(THPVariable* self, PyObject* obj, void* unused) { |
953 | HANDLE_TH_ERRORS |
954 | if (check_has_torch_function((PyObject*)self)) { |
955 | return handle_torch_function_setter(self, "volatile" , obj); |
956 | } |
957 | auto r = PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); |
958 | if (r != 0) |
959 | throw python_error(); |
960 | return 0; |
961 | END_HANDLE_TH_ERRORS_RET(-1) |
962 | } |
963 | |
964 | PyObject* THPVariable_get_output_nr(THPVariable* self, void* unused) { |
965 | HANDLE_TH_ERRORS |
966 | if (check_has_torch_function((PyObject*)self)) { |
967 | return handle_torch_function_getter(self, "output_nr" ); |
968 | } |
969 | const auto output_nr = |
970 | static_cast<long>(THPVariable_Unpack(self).output_nr()); |
971 | return PyInt_FromLong(output_nr); |
972 | END_HANDLE_TH_ERRORS |
973 | } |
974 | |
975 | PyObject* THPVariable_get_requires_grad(THPVariable* self, void* unused) { |
976 | HANDLE_TH_ERRORS |
977 | if (check_has_torch_function((PyObject*)self)) { |
978 | return handle_torch_function_getter(self, "requires_grad" ); |
979 | } |
980 | if (THPVariable_Unpack(self).requires_grad()) { |
981 | Py_RETURN_TRUE; |
982 | } else { |
983 | Py_RETURN_FALSE; |
984 | } |
985 | END_HANDLE_TH_ERRORS |
986 | } |
987 | |
988 | PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) { |
989 | HANDLE_TH_ERRORS |
990 | if (check_has_torch_function((PyObject*)self)) { |
991 | return handle_torch_function_getter(self, "retains_grad" ); |
992 | } |
993 | if (THPVariable_Unpack(self).retains_grad()) { |
994 | Py_RETURN_TRUE; |
995 | } else { |
996 | Py_RETURN_FALSE; |
997 | } |
998 | END_HANDLE_TH_ERRORS |
999 | } |
1000 | |
1001 | PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) { |
1002 | HANDLE_TH_ERRORS |
1003 | if (check_has_torch_function((PyObject*)self)) { |
1004 | return handle_torch_function_getter(self, "ndim" ); |
1005 | } |
1006 | return PyInt_FromLong(THPVariable_Unpack(self).dim()); |
1007 | END_HANDLE_TH_ERRORS |
1008 | } |
1009 | |
// Getter for `tensor.names`: a tuple with one entry per dimension, where a
// named dimension yields its name as a str and a wildcard yields None.
PyObject* THPVariable_get_names(PyObject* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_getter((THPVariable*)self, "names");
  }
  // The long-term plan is to return a list of (python) torch.Dimname.
  // However, for now, return a list of string.
  const auto& tensor = THPVariable_Unpack(self);
  size_t size = tensor.dim();
  THPObjectPtr tuple(PyTuple_New(size));
  if (!tuple)
    throw python_error();

  const auto dimnames = tensor.names();
  for (const auto i : c10::irange(size)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    PyObject* str;
    if (dimnames[i].type() == at::NameType::WILDCARD) {
      // PyTuple_SET_ITEM steals a reference to the object. When the tuple is
      // deallocated, it'll decrement the refcount on Py_None, which is bad.
      // To avoid this, we "create" a new reference to Py_None by increasing
      // the refcount.
      // Sources:
      // - https://docs.python.org/3/c-api/tuple.html#c.PyTuple_SetItem
      // -
      // https://stackoverflow.com/questions/16400600/how-to-return-a-tuple-containing-a-none-value-from-the-c-api
      Py_INCREF(Py_None);
      str = Py_None;
    } else {
      str = THPUtils_packString(dimnames[i].symbol().toUnqualString());
      if (!str)
        throw python_error();
    }
    // Steals the reference to `str` (owned reference in both branches).
    PyTuple_SET_ITEM(tuple.get(), i, str);
  }
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
1048 | |
// Setter for `tensor.names`: None clears all dimension names; otherwise a
// tuple/list of dim names is parsed and applied in place.
int THPVariable_set_names(PyObject* self, PyObject* names, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_setter((THPVariable*)self, "names", names);
  }
  const auto& var = THPVariable_Unpack(self);
  if (names == Py_None) {
    at::internal_set_names_inplace(var, at::nullopt);
  } else {
    THPUtils_assertRet(
        -1,
        THPUtils_checkDimnameList(names),
        "names must either be None or a tuple of dim names");
    at::internal_set_names_inplace(var, torch::parseDimnameList(names));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}
1067 | |
1068 | int THPVariable_set_requires_grad( |
1069 | THPVariable* self, |
1070 | PyObject* obj, |
1071 | void* unused) { |
1072 | HANDLE_TH_ERRORS |
1073 | if (check_has_torch_function((PyObject*)self)) { |
1074 | return handle_torch_function_setter(self, "requires_grad" , obj); |
1075 | } |
1076 | THPUtils_assertRet( |
1077 | -1, obj && PyBool_Check(obj), "requires_grad must be a bool" ); |
1078 | const auto& var = THPVariable_Unpack(self); |
1079 | auto requires_grad = (obj == Py_True); |
1080 | if (!var.is_leaf()) { |
1081 | THPUtils_setError( |
1082 | autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str()); |
1083 | return -1; |
1084 | } |
1085 | if (requires_grad && |
1086 | !isDifferentiableType(at::typeMetaToScalarType((var.dtype())))) { |
1087 | THPUtils_setError( |
1088 | "only Tensors of floating point and complex dtype can require gradients" ); |
1089 | return -1; |
1090 | } |
1091 | var.set_requires_grad(requires_grad); |
1092 | return 0; |
1093 | END_HANDLE_TH_ERRORS_RET(-1) |
1094 | } |
1095 | |
// Getter for `tensor.name`; returns None when the tensor has no name.
// NOTE(review): the error-handling macros wrap only the __torch_function__
// path — the plain path presumably cannot throw; confirm before changing.
PyObject* THPVariable_get_name(THPVariable* self, void* unused) {
  if (check_has_torch_function((PyObject*)self)) {
    HANDLE_TH_ERRORS
    return handle_torch_function_getter(self, "name");
    END_HANDLE_TH_ERRORS
  }
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.name().empty())
    Py_RETURN_NONE;
  return THPUtils_packString(tensor.name().c_str());
}
1107 | |
1108 | PyObject* THPVariable_get_backwards_hooks(THPVariable* self, void* unused) { |
1109 | HANDLE_TH_ERRORS |
1110 | if (check_has_torch_function((PyObject*)self)) { |
1111 | return handle_torch_function_getter(self, "_backward_hooks" ); |
1112 | } |
1113 | if (self->backward_hooks) { |
1114 | Py_INCREF(self->backward_hooks); |
1115 | return self->backward_hooks; |
1116 | } |
1117 | Py_RETURN_NONE; |
1118 | END_HANDLE_TH_ERRORS |
1119 | } |
1120 | |
// Setter for `_backward_hooks`: stores the new hooks object on the Python
// wrapper (None clears it) and re-registers the C++-side pre-hook to match.
int THPVariable_set_backwards_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_backward_hooks", obj);
  }
  THPUtils_assertRet(-1, obj, "Deletion of _backwards_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  // Incref the new value before decref'ing the old so self-assignment
  // cannot free a still-needed object.
  Py_XINCREF(obj);
  Py_XDECREF(self->backward_hooks);
  self->backward_hooks = obj;
  // Drop any previously registered hook and, if a new object was given,
  // install a fresh Python pre-hook wrapper for it.
  const auto& tensor = THPVariable_Unpack(self);
  torch::autograd::impl::clear_hooks(tensor);
  if (obj) {
    torch::autograd::impl::add_hook(
        tensor, std::make_unique<PyFunctionTensorPreHook>(obj, 0));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}
1145 | |
1146 | PyObject* THPVariable_get_base(THPVariable* self, void* unused) { |
1147 | HANDLE_TH_ERRORS |
1148 | if (check_has_torch_function((PyObject*)self)) { |
1149 | return handle_torch_function_getter(self, "_base" ); |
1150 | } |
1151 | const auto& tensor = THPVariable_Unpack(self); |
1152 | if (tensor.is_view()) { |
1153 | return THPVariable_Wrap(tensor._base()); |
1154 | } |
1155 | Py_RETURN_NONE; |
1156 | END_HANDLE_TH_ERRORS |
1157 | } |
1158 | |
1159 | PyObject* THPVariable_get_shape(THPVariable* self, void* unused) { |
1160 | HANDLE_TH_ERRORS |
1161 | if (check_has_torch_function((PyObject*)self)) { |
1162 | return handle_torch_function_getter(self, "shape" ); |
1163 | } |
1164 | return THPSize_NewFromSymSizes(THPVariable_Unpack(self)); |
1165 | END_HANDLE_TH_ERRORS |
1166 | } |
1167 | |
1168 | PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) { |
1169 | HANDLE_TH_ERRORS |
1170 | if (check_has_torch_function((PyObject*)self)) { |
1171 | return handle_torch_function_getter(self, "is_cpu" ); |
1172 | } |
1173 | auto& self_ = THPVariable_Unpack(self); |
1174 | return torch::autograd::utils::wrap(self_.is_cpu()); |
1175 | END_HANDLE_TH_ERRORS |
1176 | } |
1177 | |
1178 | PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) { |
1179 | HANDLE_TH_ERRORS |
1180 | if (check_has_torch_function((PyObject*)self)) { |
1181 | return handle_torch_function_getter(self, "is_cuda" ); |
1182 | } |
1183 | auto& self_ = THPVariable_Unpack(self); |
1184 | return torch::autograd::utils::wrap(self_.is_cuda()); |
1185 | END_HANDLE_TH_ERRORS |
1186 | } |
1187 | |
1188 | PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) { |
1189 | HANDLE_TH_ERRORS |
1190 | if (check_has_torch_function((PyObject*)self)) { |
1191 | return handle_torch_function_getter(self, "is_ipu" ); |
1192 | } |
1193 | auto& self_ = THPVariable_Unpack(self); |
1194 | return torch::autograd::utils::wrap(self_.is_ipu()); |
1195 | END_HANDLE_TH_ERRORS |
1196 | } |
1197 | |
1198 | PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) { |
1199 | HANDLE_TH_ERRORS |
1200 | if (check_has_torch_function((PyObject*)self)) { |
1201 | return handle_torch_function_getter(self, "is_xpu" ); |
1202 | } |
1203 | auto& self_ = THPVariable_Unpack(self); |
1204 | return torch::autograd::utils::wrap(self_.is_xpu()); |
1205 | END_HANDLE_TH_ERRORS |
1206 | } |
1207 | |
1208 | PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) { |
1209 | HANDLE_TH_ERRORS |
1210 | if (check_has_torch_function((PyObject*)self)) { |
1211 | return handle_torch_function_getter(self, "is_sparse" ); |
1212 | } |
1213 | auto& self_ = THPVariable_Unpack(self); |
1214 | return torch::autograd::utils::wrap(self_.is_sparse()); |
1215 | END_HANDLE_TH_ERRORS |
1216 | } |
1217 | |
1218 | PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) { |
1219 | HANDLE_TH_ERRORS |
1220 | if (check_has_torch_function((PyObject*)self)) { |
1221 | return handle_torch_function_getter(self, "is_sparse_csr" ); |
1222 | } |
1223 | auto& self_ = THPVariable_Unpack(self); |
1224 | return torch::autograd::utils::wrap(self_.is_sparse_csr()); |
1225 | END_HANDLE_TH_ERRORS |
1226 | } |
1227 | |
1228 | PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) { |
1229 | HANDLE_TH_ERRORS |
1230 | if (check_has_torch_function((PyObject*)self)) { |
1231 | return handle_torch_function_getter(self, "is_mkldnn" ); |
1232 | } |
1233 | auto& self_ = THPVariable_Unpack(self); |
1234 | return torch::autograd::utils::wrap(self_.is_mkldnn()); |
1235 | END_HANDLE_TH_ERRORS |
1236 | } |
1237 | |
1238 | PyObject* THPVariable_is_mps(THPVariable* self, void* unused) { |
1239 | HANDLE_TH_ERRORS |
1240 | if (check_has_torch_function((PyObject*)self)) { |
1241 | return handle_torch_function_getter(self, "is_mps" ); |
1242 | } |
1243 | auto& self_ = THPVariable_Unpack(self); |
1244 | return torch::autograd::utils::wrap(self_.is_mps()); |
1245 | END_HANDLE_TH_ERRORS |
1246 | } |
1247 | |
1248 | PyObject* THPVariable_is_ort(THPVariable* self, void* unused) { |
1249 | HANDLE_TH_ERRORS |
1250 | if (check_has_torch_function((PyObject*)self)) { |
1251 | return handle_torch_function_getter(self, "is_ort" ); |
1252 | } |
1253 | auto& self_ = THPVariable_Unpack(self); |
1254 | return torch::autograd::utils::wrap(self_.is_ort()); |
1255 | END_HANDLE_TH_ERRORS |
1256 | } |
1257 | |
1258 | PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) { |
1259 | HANDLE_TH_ERRORS |
1260 | if (check_has_torch_function((PyObject*)self)) { |
1261 | return handle_torch_function_getter(self, "is_vulkan" ); |
1262 | } |
1263 | auto& self_ = THPVariable_Unpack(self); |
1264 | return torch::autograd::utils::wrap(self_.is_vulkan()); |
1265 | END_HANDLE_TH_ERRORS |
1266 | } |
1267 | |
1268 | PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) { |
1269 | HANDLE_TH_ERRORS |
1270 | if (check_has_torch_function((PyObject*)self)) { |
1271 | return handle_torch_function_getter(self, "is_quantized" ); |
1272 | } |
1273 | auto& self_ = THPVariable_Unpack(self); |
1274 | return torch::autograd::utils::wrap(self_.is_quantized()); |
1275 | END_HANDLE_TH_ERRORS |
1276 | } |
1277 | |
1278 | PyObject* THPVariable_is_meta(THPVariable* self, void* unused) { |
1279 | HANDLE_TH_ERRORS |
1280 | if (check_has_torch_function((PyObject*)self)) { |
1281 | return handle_torch_function_getter(self, "is_meta" ); |
1282 | } |
1283 | auto& self_ = THPVariable_Unpack(self); |
1284 | return torch::autograd::utils::wrap(self_.is_meta()); |
1285 | END_HANDLE_TH_ERRORS |
1286 | } |
1287 | |
1288 | PyObject* THPVariable_is_complex(THPVariable* self, void* unused) { |
1289 | HANDLE_TH_ERRORS |
1290 | if (check_has_torch_function((PyObject*)self)) { |
1291 | return handle_torch_function_getter(self, "is_complex" ); |
1292 | } |
1293 | auto& self_ = THPVariable_Unpack(self); |
1294 | return torch::autograd::utils::wrap(self_.is_complex()); |
1295 | END_HANDLE_TH_ERRORS |
1296 | } |
1297 | |
1298 | PyObject* THPVariable_is_nested(THPVariable* self, void* unused) { |
1299 | HANDLE_TH_ERRORS |
1300 | if (check_has_torch_function((PyObject*)self)) { |
1301 | return handle_torch_function_getter(self, "is_nested" ); |
1302 | } |
1303 | auto& self_ = THPVariable_Unpack(self); |
1304 | return torch::autograd::utils::wrap(self_.is_nested()); |
1305 | END_HANDLE_TH_ERRORS |
1306 | } |
1307 | |
1308 | PyObject* THPVariable_has_symbolic_sizes_strides( |
1309 | THPVariable* self, |
1310 | void* unused) { |
1311 | HANDLE_TH_ERRORS |
1312 | auto& self_ = THPVariable_Unpack(self); |
1313 | return torch::autograd::utils::wrap( |
1314 | self_.unsafeGetTensorImpl()->has_symbolic_sizes_strides()); |
1315 | END_HANDLE_TH_ERRORS |
1316 | } |
1317 | |
1318 | static PyObject* THPVariable_dtype(THPVariable* self, void* unused) { |
1319 | HANDLE_TH_ERRORS |
1320 | if (check_has_torch_function((PyObject*)self)) { |
1321 | return handle_torch_function_getter(self, "dtype" ); |
1322 | } |
1323 | auto& self_ = THPVariable_Unpack(self); |
1324 | return torch::autograd::utils::wrap(torch::getTHPDtype(self_.scalar_type())); |
1325 | END_HANDLE_TH_ERRORS |
1326 | } |
1327 | |
1328 | static PyObject* THPVariable_layout(THPVariable* self, void* unused) { |
1329 | HANDLE_TH_ERRORS |
1330 | if (check_has_torch_function((PyObject*)self)) { |
1331 | return handle_torch_function_getter(self, "layout" ); |
1332 | } |
1333 | auto& self_ = THPVariable_Unpack(self); |
1334 | return torch::autograd::utils::wrap(torch::getTHPLayout(self_.layout())); |
1335 | END_HANDLE_TH_ERRORS |
1336 | } |
1337 | |
1338 | static PyObject* THPVariable_device(THPVariable* self, void* unused) { |
1339 | HANDLE_TH_ERRORS |
1340 | if (check_has_torch_function((PyObject*)self)) { |
1341 | return handle_torch_function_getter(self, "device" ); |
1342 | } |
1343 | return THPDevice_New(THPVariable_Unpack(self).device()); |
1344 | END_HANDLE_TH_ERRORS |
1345 | } |
1346 | |
1347 | int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) { |
1348 | HANDLE_TH_ERRORS |
1349 | auto& self_ = THPVariable_Unpack(self); |
1350 | auto self_real = at::real(self_); |
1351 | auto real_ = valueToTensor(self_real.options(), real, self_real.device()); |
1352 | { |
1353 | pybind11::gil_scoped_release no_gil; |
1354 | self_real.copy_(real_); |
1355 | return 0; |
1356 | } |
1357 | END_HANDLE_TH_ERRORS_RET(-1) |
1358 | } |
1359 | |
1360 | int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) { |
1361 | HANDLE_TH_ERRORS |
1362 | auto& self_ = THPVariable_Unpack(self); |
1363 | auto self_imag = at::imag(self_); |
1364 | auto imag_ = valueToTensor(self_imag.options(), imag, self_imag.device()); |
1365 | { |
1366 | pybind11::gil_scoped_release no_gil; |
1367 | self_imag.copy_(imag_); |
1368 | return 0; |
1369 | } |
1370 | END_HANDLE_TH_ERRORS_RET(-1) |
1371 | } |
1372 | |
// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// Entries are {name, getter, setter, doc, closure}; a nullptr setter makes
// the property read-only. The table is terminated by the all-null entry.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPVariable_properties[] = {
    {"_python_dispatch",
     (getter)THPVariable_get_python_dispatch,
     nullptr,
     nullptr,
     nullptr},
    {"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr},
    {"H", (getter)PropertyH::getter, nullptr, nullptr, nullptr},
    {"mT", (getter)PropertymT::getter, nullptr, nullptr, nullptr},
    {"mH", (getter)PropertymH::getter, nullptr, nullptr, nullptr},
    {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
    {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
    {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
    {"_grad_fn",
     (getter)THPVariable_get_grad_fn,
     (setter)THPVariable_set_grad_fn,
     nullptr,
     nullptr},
    {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
    {"retains_grad",
     (getter)THPVariable_retains_grad,
     nullptr,
     nullptr,
     nullptr},
    {"data",
     (getter)PropertyData::getter,
     (setter)THPVariable_set_data,
     nullptr,
     nullptr},
    {"_grad",
     (getter)PropertyGrad::getter,
     (setter)THPVariable_set_grad,
     nullptr,
     nullptr}, // Allows the python class to override .grad
    {"grad",
     (getter)PropertyGrad::getter,
     (setter)THPVariable_set_grad,
     nullptr,
     nullptr},
    {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr},
    {"volatile",
     (getter)THPVariable_get_volatile,
     (setter)THPVariable_set_volatile,
     nullptr,
     nullptr},
    {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
    {"requires_grad",
     (getter)THPVariable_get_requires_grad,
     (setter)THPVariable_set_requires_grad,
     nullptr,
     nullptr},
    {"_backward_hooks",
     (getter)THPVariable_get_backwards_hooks,
     (setter)THPVariable_set_backwards_hooks,
     nullptr,
     nullptr},
    {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
    {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
    {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
    {"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
    {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
    {"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
    {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
    {"is_sparse_csr",
     (getter)THPVariable_is_sparse_csr,
     nullptr,
     nullptr,
     nullptr},
    {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
    {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr},
    {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
    {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
    {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
    {"is_quantized",
     (getter)THPVariable_is_quantized,
     nullptr,
     nullptr,
     nullptr},
    {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
    {"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr},
    {"_has_symbolic_sizes_strides",
     (getter)THPVariable_has_symbolic_sizes_strides,
     nullptr,
     nullptr,
     nullptr},
    {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
    {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
    {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
    {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
    {"names",
     (getter)THPVariable_get_names,
     (setter)THPVariable_set_names,
     nullptr,
     nullptr},
    {"real",
     (getter)PropertyReal::getter,
     (setter)THPVariable_set_real,
     nullptr,
     nullptr},
    {"imag",
     (getter)PropertyImag::getter,
     (setter)THPVariable_set_imag,
     nullptr,
     nullptr},
    {nullptr}};
1481 | |
// Mapping protocol slots: len(tensor), tensor[key], tensor[key] = value.
static PyMappingMethods THPVariable_as_mapping = {
    THPVariable_length,
    THPVariable_getitem,
    THPVariable_setitem,
};
1487 | |
1488 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
1489 | static PyMethodDef [] = { |
1490 | {"as_subclass" , |
1491 | castPyCFunctionWithKeywords(THPVariable_as_subclass), |
1492 | METH_VARARGS | METH_KEYWORDS, |
1493 | nullptr}, |
1494 | {"_make_subclass" , |
1495 | castPyCFunctionWithKeywords(THPVariable_make_subclass), |
1496 | METH_STATIC | METH_VARARGS | METH_KEYWORDS, |
1497 | nullptr}, |
1498 | {"_make_wrapper_subclass" , |
1499 | castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass), |
1500 | METH_STATIC | METH_VARARGS | METH_KEYWORDS, |
1501 | nullptr}, |
1502 | {"_fix_weakref" , THPVariable_fix_weakref, METH_NOARGS, nullptr}, |
1503 | {"_view_func" , THPVariable_view_func, METH_O, nullptr}, |
1504 | {nullptr}}; |
1505 | |
// Instance layout for the torch._C._TensorMeta metaclass: a plain heap
// type with no additional C-level state.
struct THPVariableMeta {
  PyHeapTypeObject base;
};

// Forward declaration of the metaclass tp_init slot.
int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
1511 | |
// Type object for torch._C._TensorMeta, the metaclass of _TensorBase.
// It derives from `type` (tp_base = PyType_Type, patched at module init via
// DEFERRED_ADDRESS) and only overrides tp_init.
PyTypeObject THPVariableMetaType = {
    PyVarObject_HEAD_INIT(
        DEFERRED_ADDRESS(&PyType_Type),
        0) "torch._C._TensorMeta", /* tp_name */
    sizeof(THPVariableMeta), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    THPVariableMetaType_init, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
1553 | |
// Type object for torch._C._TensorBase. Its metaclass is
// THPVariableMetaType above; several slots (tp_dealloc, GC hooks) are left
// null here because the metaclass fills them in for subclasses.
PyTypeObject THPVariableType = {
    PyVarObject_HEAD_INIT(
        &THPVariableMetaType,
        0) "torch._C._TensorBase", /* tp_name */
    sizeof(THPVariable), /* tp_basicsize */
    0, /* tp_itemsize */
    // This is unspecified, because it is illegal to create a THPVariableType
    // directly. Subclasses will have their tp_dealloc set appropriately
    // by the metaclass
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    &THPVariable_as_mapping, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    // Also set by metaclass
    (traverseproc)THPFunction_traverse, /* tp_traverse */
    (inquiry)THPVariable_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    THPVariable_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    // Although new is provided here, it is illegal to call this with cls ==
    // THPVariableMeta. Instead, subclass it first and then construct it
    THPVariable_pynew, /* tp_new */
};
1602 | |
1603 | PyObject* THPVariable_pynew( |
1604 | PyTypeObject* type, |
1605 | PyObject* args, |
1606 | PyObject* kwargs) { |
1607 | HANDLE_TH_ERRORS |
1608 | TORCH_CHECK( |
1609 | type != &THPVariableType, |
1610 | "Cannot directly construct _TensorBase; subclass it and then construct that" ); |
1611 | jit::tracer::warn("torch.Tensor" , jit::tracer::WARN_CONSTRUCTOR); |
1612 | auto tensor = torch::utils::base_tensor_ctor(args, kwargs); |
1613 | // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was |
1614 | // given a raw pointer that will refcount bump |
1615 | // NB: base_tensor_ctor can call into dispatched ATen functions (e.g., |
1616 | // alias(), lift_fresh()) which can return Tensor subclasses. We allow |
1617 | // these to be passed on directly. |
1618 | return THPVariable_NewWithVar( |
1619 | type, |
1620 | std::move(tensor), |
1621 | c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, |
1622 | /*allow_preexisting_pyobj=*/true); |
1623 | END_HANDLE_TH_ERRORS |
1624 | } |
1625 | |
1626 | static void clear_slots(PyTypeObject* type, PyObject* self) { |
1627 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1628 | Py_ssize_t i, n; |
1629 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1630 | PyMemberDef* mp; |
1631 | |
1632 | n = Py_SIZE(type); |
1633 | mp = type->tp_members; |
1634 | for (i = 0; i < n; i++, mp++) { |
1635 | if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) { |
1636 | char* addr = (char*)self + mp->offset; |
1637 | PyObject* obj = *(PyObject**)addr; |
1638 | if (obj != nullptr) { |
1639 | *(PyObject**)addr = nullptr; |
1640 | Py_DECREF(obj); |
1641 | } |
1642 | } |
1643 | } |
1644 | } |
1645 | |
// NB: this is not the tp_dealloc on THPVariable; instead, it's the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
void THPVariable_subclass_dealloc(PyObject* self) {
  // If the C++ side still needs the PyObject alive, resurrect it instead of
  // tearing it down (see THPVariable_tryResurrect).
  if (THPVariable_tryResurrect((THPVariable*)self))
    return;

  // This is like a crappy version of subtype_dealloc.
  // Unfortunately, we cannot directly delegate to
  // subtype_dealloc as it will start walking the parent
  // chain *starting with* the type of self, which will cause
  // us to go back to our custom dealloc.
  //
  // We have to replicate the subtype_dealloc logic to ensure
  // that finalizers are handled correctly
  PyTypeObject* type = Py_TYPE(self);
  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented" );

  // Untrack from the GC before we start mutating the object, matching
  // subtype_dealloc.
  PyObject_GC_UnTrack(self);
  // TODO: consider using trash can

  bool has_finalizer = type->tp_finalize || type->tp_del;

  if (type->tp_finalize) {
    // The object must be GC-tracked while arbitrary Python code
    // (the PEP 442 finalizer, i.e. __del__) runs.
    PyObject_GC_Track(self);
    if (PyObject_CallFinalizerFromDealloc(self) < 0) {
      /* Resurrected */
      return;
    }
    PyObject_GC_UnTrack(self);
  }

  // base test is unnecessary as THPVariable does not set this
  if (type->tp_weaklistoffset) {
    PyObject_ClearWeakRefs(self);
  }

  if (type->tp_del) {
    // Legacy tp_del finalizer path; same track/untrack dance as above.
    PyObject_GC_Track(self);
    type->tp_del(self);
    if (self->ob_refcnt > 0) {
      /* Resurrected */
      return;
    }
    PyObject_GC_UnTrack(self);
  }

  if (has_finalizer) {
    /* New weakrefs could be created during the finalizer call.
       If this occurs, clear them out without calling their
       finalizers since they might rely on part of the object
       being finalized that has already been destroyed. */
    if (type->tp_weaklistoffset) {
      /* Modeled after GET_WEAKREFS_LISTPTR() */
      PyWeakReference** list =
          (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
      while (*list)
        _PyWeakref_ClearRef(*list);
    }
  }

  // Clear all slots until we get to base class THPVariableType
  {
    PyTypeObject* base = type;
    while (base != &THPVariableType) {
      if (Py_SIZE(base)) {
        clear_slots(base, self);
      }
      base = base->tp_base;
      // THPVariableType must appear somewhere on the MRO chain, so we can
      // never fall off the end.
      TORCH_INTERNAL_ASSERT(base);
    }
  }

  // All Python defined classes have __dict__
  if (C10_LIKELY(type->tp_dictoffset)) {
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr != nullptr) {
      PyObject* dict = *dictptr;
      if (dict != nullptr) {
        Py_DECREF(dict);
        *dictptr = nullptr;
      }
    }
  }

  // subtype_dealloc allows for this but we don't
  TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);

  // Finally clear out the base THPVariable
  THPVariable_clear((THPVariable*)self);
  // cdata was placement-new'ed in THPVariable_NewWithVar, so we must run the
  // destructor explicitly before tp_free releases the memory.
  ((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
  Py_TYPE(self)->tp_free(self);

  // Python defined subclasses should always be on the heap
  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  // Instances of heap types own a reference to their type; drop it, as
  // subtype_dealloc does.
  Py_DECREF(type);
}
1744 | |
// Creates a new Python object for a Variable. The status parameter
// specifies what the interpreter tag status on the object is; for
// example, if you ran check_pyobj, the return optional of this object
// tells you if the tensor was already tagged or not so you can pass
// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where
// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED.
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
    Variable _var,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj) {
  // Make sure that the reinterpret into a THPVariable* will be valid
  TORCH_CHECK(
      PyType_IsSubtype(type, &THPVariableType),
      "Creating a Tensor subclass from a class " ,
      "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor." );

  // This function overwrite the Tensor's pyobj field without extra checks
  // Make sure it is not set otherwise we would leak memory
  auto mb_obj =
      _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());

  // Under some circumstances, we may attempt to create a new Python
  // object for a variable that already has a Python object. The most common
  // situation this can occur is if you have a TorchDispatchMode active that
  // is returning a subclass from lift_fresh (which is invoked to
  // appropriately "wrap" a constant tensor into whatever ambient modes are
  // active.)
  //
  // In general, it is impossible to handle this case compositionally.
  // Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
  // that is transforming all ops (including the internal lift_fresh call that
  // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
  // BTensor, where ATensor and BTensor are completely unrelated subclasses
  // and there is no way to compose them. There is no way to satisfy the user
  // request here: in particular, you can't just try to re-invoke the ATensor
  // constructor on the returned BTensor, because (1) this could cause an
  // infinite loop--we are already in ATensor.__new__ and (2) there isn't any
  // guarantee that ATensor.__new__ supports a single element constructor
  // anyway.
  //
  // However, a more common case is a user just called torch.Tensor([1, 2, 3]),
  // and a fake tensor mode is active. Really, all you want is to get back
  // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
  // would have returned a fake tensor (concretely, the way this happens
  // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
  // turns into a FakeTensor when we call lift_fresh on this real tensor).
  // This case is compositional because FakeTensor is a subclass of Tensor, so
  // it's valid for us to return it in place of a Tensor. So this is what we
  // do.

  if (mb_obj.has_value() && mb_obj.value()) {
    TORCH_CHECK(
        allow_preexisting_pyobj,
        "Creating a new Tensor subclass " ,
        type->tp_name,
        " but the raw Tensor object is already associated to a python object " ,
        "of type " ,
        mb_obj.value()->ob_type->tp_name);
    // Even if we allow pre-existing PyObject, we don't allow completely
    // ignoring the requested type. Check that we fulfilled a subtype
    // relation here. In the common case the requested type is Tensor and
    // this always succeeds.
    PyObject* obj = *mb_obj;
    // Check if it's OK to just directly return the Python object without
    // allocating a new variable. We just check that the existing Python
    // object is a subclass of the requested type.
    PyTypeObject* obj_type = Py_TYPE(obj);
    TORCH_CHECK(
        obj_type == type || PyType_IsSubtype(obj_type, type),
        "Creating a new Tensor subclass " ,
        type->tp_name,
        " but the raw Tensor object is already associated to a python object " ,
        "of type " ,
        mb_obj.value()->ob_type->tp_name,
        " which is not a subclass of the "
        "requested type" );
    // We may (in fact, we typically will) need to resurrect this
    return THPVariable_Wrap(std::move(_var));
  }

  PyObject* obj = type->tp_alloc(type, 0);
  if (obj) {
    auto v = (THPVariable*)obj;
    // TODO: named constructor to avoid default initialization
    // tp_alloc does not run C++ constructors, so placement-new the
    // MaybeOwned before assigning to it.
    new (&v->cdata) MaybeOwned<Variable>();
    if (c10::impl::HermeticPyObjectTLS::get_state()) {
      // Do NOT initialize pyobj field on the tensor, you own the C++
      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
      TORCH_INTERNAL_ASSERT(
          !check_has_torch_dispatch(obj),
          "While HermeticPyObject was enabled, we attempted to create a tensor "
          "subclass with __torch_dispatch__. This violates the invariant that "
          "operations in HermeticPyObject have equivalent C++ implementations. "
          "If your operator registered from Python operator registration isn't "
          "doing anything strange, there may be an internal PyTorch bug involving "
          "not appropriately disabling TorchDispatchMode before executing "
          "Python op registration." );
    } else {
      // Normal codepath: link the TensorImpl's pyobj slot back to this
      // freshly allocated PyObject.
      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
      const auto& var = THPVariable_Unpack(v);
      var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
          getPyInterpreter(), obj, status);
      if (check_has_torch_dispatch(obj)) {
        var.unsafeGetTensorImpl()->set_python_dispatch(true);
      }
    }
  }
  return obj;
}
1857 | |
1858 | /// NOTE [ PyObject Traversal ] |
1859 | /// |
1860 | /// PyObjects that are wrapping c++ objects can lead to non-trivial traverse |
1861 | /// logic and it can be tricky to know what to traverse and when. This note |
1862 | /// tries to clarify what is the danger here and a simple algorithm to choose |
1863 | /// how to write the tp_traverse and tp_clear functions. If you're not already |
1864 | /// familiar with how the CPython GC works, you should read this in-depth |
1865 | /// description: https://devguide.python.org/garbage_collector/ |
1866 | /// |
1867 | /// The complexity for us comes from the fact that some c++ shared_ptr objects |
1868 | /// own references to python objects and are also owned both by other python |
1869 | /// objects and c++ objects. This means that to allow the GC to collect all |
1870 | /// cycles, we need to properly implement the traverse/clear methods that take |
1871 | /// into account these C++ ownership links. |
1872 | /// |
1873 | /// The main danger here comes from the fact that, while all python-related code |
1874 | /// is thread safe wrt the GC execution (thanks to the GIL), other threads might |
1875 | /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count |
1876 | /// going up or down in between the different traverse/clear invocations. The |
1877 | /// one constraint we add here that is not explicitly mentioned in the GC |
1878 | /// description above is that for a given GC run (meaning while the GIL is |
1879 | /// held), the traverse/clear pair should never report different ownership |
1880 | /// relations: if traverse visited a given PyObject, then the clear within that |
1881 | /// same GC run must still be the sole owner and clear that PyObject. |
1882 | /// |
1883 | /// A more mechanical algorithm to know what to traverse/clear is as follows: |
1884 | /// - Any field on this PyObject that contains a strong reference to another |
1885 | /// PyObject |
1886 | /// must be visited and cleared. An example of that is the "backward_hooks" |
1887 | /// field of the THPVariable. |
1888 | /// - Any field that contains a C++ object that is uniquely owned by this |
1889 | /// PyObject (either |
1890 | /// a unique_ptr or a shared_ptr with use_count==1) should have all the |
1891 | /// PyObject it owns visited and cleared. An example would be here the |
1892 | /// tensor hooks. |
1893 | /// - If that uniquely owned C++ object also uniquely owns other C++ objects, |
1894 | /// these should be |
1895 | /// visited and cleared as well if they contain any PyObject. |
1896 | /// |
1897 | /// Caveat: to avoid slow runtime, we limit the depth of this exploration of C++ |
1898 | /// objects in practice and we do not, for example, go through the whole |
1899 | /// autograd graph, even if it is uniquely owned. This is a known place where |
1900 | /// users can create noncollectable cycles as described in: |
1901 | /// https://github.com/pytorch/pytorch/issues/7343 |
1902 | /// |
1903 | |
1904 | static int traverse_slots( |
1905 | PyTypeObject* type, |
1906 | PyObject* self, |
1907 | visitproc visit, |
1908 | void* arg) { |
1909 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1910 | Py_ssize_t i, n; |
1911 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1912 | PyMemberDef* mp; |
1913 | |
1914 | n = Py_SIZE(type); |
1915 | mp = type->tp_members; |
1916 | for (i = 0; i < n; i++, mp++) { |
1917 | if (mp->type == T_OBJECT_EX) { |
1918 | char* addr = (char*)self + mp->offset; |
1919 | PyObject* obj = *(PyObject**)addr; |
1920 | if (obj != nullptr) { |
1921 | int err = visit(obj, arg); |
1922 | if (err) |
1923 | return err; |
1924 | } |
1925 | } |
1926 | } |
1927 | return 0; |
1928 | } |
1929 | |
static int THPVariable_subclass_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  // If the tensor is eligible to be resurrected, don't traverse it; instead
  // treat all of its references as a root (as they WOULD be a root since we
  // can treat the inbound C++ references as root owners).
  //
  // This works because unlike conventional GCs, Python's GC operates in two
  // phases: first it uses traverse to discover roots, and then it uses traverse
  // to do reachability. Bypassing traverse during root discovery forces Python
  // to treat self as a root for everything it refers to. For a full
  // explanation of the algorithm see
  // https://devguide.python.org/garbage_collector/
  //
  // NB: if we don't hold an owning reference to the underlying Tensor, it is
  // possible that the underlying Tensor has already gone dead. In that case,
  // it's not safe to access it. But it's also safe to traverse, because if
  // the underlying Tensor *is* live, then root discovery will determine that
  // self is live, and nothing will get GC'ed anyway (resurrection cannot happen
  // if the C++ objects owns the PyObject)
  THPVariable* var = reinterpret_cast<THPVariable*>(self);
  if (isResurrectable(var)) {
    return 0;
  }

  // Crappy version of subtype_traverse; same deal as
  // THPVariable_subclass_dealloc

  PyTypeObject* type = Py_TYPE(self);
  // Traverse slots until we get to base class THPVariableType
  {
    PyTypeObject* base = type;
    while (base != &THPVariableType) {
      if (Py_SIZE(base)) {
        int err = traverse_slots(base, self, visit, arg);
        if (err)
          return err;
      }
      base = base->tp_base;
      TORCH_INTERNAL_ASSERT(base);
    }
  }

  // All Python defined classes have __dict__
  if (C10_LIKELY(type->tp_dictoffset)) {
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr && *dictptr)
      Py_VISIT(*dictptr);
  }

  // Instances of heap types hold a reference to their type, which must be
  // visited as well (as subtype_traverse does).
  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  Py_VISIT(type);

  // Finally traverse THPVariable special stuff
  Py_VISIT(var->backward_hooks);
  if (!var->cdata.unsafeIsBorrowed()) {
    const auto& tensor = THPVariable_Unpack(var);
    if (tensor.defined()) {
      // WARNING: The grad_fn traversal logic is very subtle, if you change
      // this, be very careful not to re-introduce this bug:
      // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c

      // We ensure that we follow NOTE [ PyObject Traversal ] here by checking
      // that this python object is the sole owner of the underlying Tensor and
      // that this Tensor is the sole owner of its grad_fn. In this case, the
      // only way to get a new reference to the grad_fn is by using this python
      // object, which requires the GIL to be accessed. Note that this is only
      // valid as long as users don't share non-owning references across
      // different threads (which is crazy and should never be done).
      auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
      if (tensor.use_count() == 1) {
        if (autograd_meta) {
          // Do NOT call grad_fn() here as that might trigger a recompute
          const auto& grad_fn = autograd_meta->grad_fn_;
          if (grad_fn && grad_fn.use_count() == 1) {
            // All Node can have a pyobj (stored in "pyobj_")
            Py_VISIT(grad_fn->pyobj());
            // PyNode are special as they also have an "obj" field
            if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
              Py_VISIT(py_node_fn->obj);
            }
          }
        }
      }
      if (autograd_meta) {
        // Python-defined tensor pre-hooks hold a dict of PyObjects; visit it.
        for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
          if (auto pyhook =
                  dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
            Py_VISIT(pyhook->dict);
          }
        }
      }
    }
  }

  return 0;
}
2028 | |
2029 | int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { |
2030 | if (PyType_Type.tp_init(cls, args, kwargs) < 0) { |
2031 | return -1; |
2032 | } |
2033 | ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc; |
2034 | ((PyTypeObject*)cls)->tp_traverse = |
2035 | (traverseproc)THPVariable_subclass_traverse; |
2036 | return 0; |
2037 | } |
2038 | |
2039 | namespace torch { |
2040 | namespace autograd { |
2041 | |
2042 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
2043 | extern PyMethodDef variable_methods[]; |
2044 | extern void initTorchFunctions(PyObject* module); |
2045 | |
2046 | void initTensorImplConversion(PyObject* module) { |
2047 | auto m = py::handle(module).cast<py::module>(); |
2048 | m.def("_wrap_tensor_impl" , [](void* ptr) { |
2049 | auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>:: |
2050 | unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr)); |
2051 | TORCH_CHECK(p.defined(), "Can't wrap undefined tensor" ); |
2052 | auto tensor = at::Tensor::wrap_tensor_impl(std::move(p)); |
2053 | return py::cast(std::move(tensor)); |
2054 | }); |
2055 | // set on the module level to avoid mixing pybind and plain CPython extensions |
2056 | m.def("_tensor_impl_raw_handle" , [](torch::autograd::Variable* t) -> void* { |
2057 | // We return a raw non-owning pointer here, we rely on surrounding |
2058 | // code to keep the original tensor alive |
2059 | return t->getIntrusivePtr().get(); |
2060 | }); |
2061 | } |
2062 | } // namespace autograd |
2063 | } // namespace torch |
2064 | |
// Registers _TensorMeta and _TensorBase on `module` and wires up the rest of
// the tensor Python bindings. Returns false (with a Python error pending)
// if either type fails PyType_Ready.
bool THPVariable_initModule(PyObject* module) {
  THPVariableMetaType.tp_base = &PyType_Type;
  if (PyType_Ready(&THPVariableMetaType) < 0)
    return false;
  // PyModule_AddObject steals a reference on success; the Py_INCREF keeps
  // the static type alive independently of the module.
  Py_INCREF(&THPVariableMetaType);
  PyModule_AddObject(module, "_TensorMeta" , (PyObject*)&THPVariableMetaType);

  // tp_methods must point at storage that outlives the type, hence the
  // function-local static vector.
  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
  THPUtils_addPyMethodDefs(methods, extra_methods);
  THPVariableType.tp_methods = methods.data();
  if (PyType_Ready(&THPVariableType) < 0)
    return false;
  Py_INCREF(&THPVariableType);
  PyModule_AddObject(module, "_TensorBase" , (PyObject*)&THPVariableType);
  torch::autograd::initTorchFunctions(module);
  torch::autograd::initTensorImplConversion(module);
  // NOTE(review): name suggests this probes the installed NumPy for a DLPack
  // deleter bug — see torch/csrc/utils/tensor_numpy.h to confirm.
  torch::utils::validate_numpy_for_dlpack_deleter_bug();
  return true;
}
2085 | |