#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <string>

using namespace torch;
using namespace at;
using namespace c10;

namespace {

// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_TRACE_CUDA(func_name, ...)                           \
  at::impl::MaybeSetTLSOnEntryGuard guard;                            \
  if (Py_IsInitialized()) {                                           \
    pybind11::gil_scoped_acquire gil;                                 \
    try {                                                             \
      py::module mod = py::module::import("torch.utils._cuda_trace"); \
      py::object hook = mod.attr(func_name).attr("fire_callbacks");   \
      hook(__VA_ARGS__);                                              \
    } catch (const std::exception& e) {                               \
      LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
    }                                                                 \
  }
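
// Illustrative note: the Python side is expected to expose, per event type, a
// registry object in torch.utils._cuda_trace (e.g. CUDAEventCreationCallbacks)
// whose fire_callbacks(*args) invokes every registered callback with the raw
// uintptr_t values passed above. A hedged sketch of a consumer, assuming the
// registration helpers in torch.utils._cuda_trace:
//
//   from torch.utils import _cuda_trace
//   _cuda_trace.register_callback_for_cuda_event_creation(
//       lambda event: print(f"CUDA event created: {event:#x}"))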

struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;

  void decref(PyObject* pyobj, bool is_tensor) const override;

  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
  c10::intrusive_ptr<c10::TensorImpl> detach(
      const c10::TensorImpl* self) const override;

  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override;
  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override;
  // NB: this is defined in python_dispatch.cpp
  void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey key,
      torch::jit::Stack* stack) const override {
    torch::impl::dispatch::python_op_registration_trampoline_impl(
        op, key, stack);
  }

  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
  c10::Device device(const c10::TensorImpl* self) const override;
  int64_t dim(const c10::TensorImpl* self) const override;
  c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
  c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
  c10::Layout layout(const c10::TensorImpl* self) const override;
  c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
  c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;

  void trace_gpu_event_creation(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event);
  }
  void trace_gpu_event_deletion(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event);
  }
  void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
      const override {
    CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream);
  }
  void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream);
  }
  void trace_gpu_memory_allocation(uintptr_t ptr) const override {
    CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr);
  }
  void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
    CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr);
  }
  void trace_gpu_stream_creation(uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream);
  }
  void trace_gpu_device_synchronization() const override {
    CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks");
  }
  void trace_gpu_stream_synchronization(uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream);
  }
  void trace_gpu_event_synchronization(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
  }

  void reset_backward_hooks(const c10::TensorImpl* self) const override;

  static ConcretePyInterpreterVTable* instance() {
    static ConcretePyInterpreterVTable s;
    return &s;
  }
};

class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())) {
    is_main_interpreter_ =
        at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_);
  }
  // NB: intentionally leaks the PyInterpreter, as there may still be live
  // references to it stored in objects that are not destructed while Python
  // is being cleaned up.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_;
};
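
// A single static PyInterpreterHolder (self_interpreter, defined near the
// bottom of this file) owns the PyInterpreter for this Python interpreter;
// getPyInterpreter() below hands out the raw pointer.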

py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");

  std::vector<py::handle> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());
  auto args =
      py::reinterpret_steal<py::object>(PyTuple_New(1 + extra_args.size()));
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  int64_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr)
      throw python_error();
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    i++;
  }

  py::dict kwargs;

  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}
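
// Usage sketch (this mirrors the callers below, e.g. is_contiguous): resolve
// the torch.ops.* overload once, then forward the TensorImpl plus any extra
// non-tensor arguments:
//
//   auto out = torchDispatchFromTensorImpl(
//       self,
//       "is_contiguous",
//       py::module::import("torch")
//           .attr("ops")
//           .attr("aten")
//           .attr("is_contiguous")
//           .attr("memory_format")
//           .ptr(),
//       "torch.ops.aten",
//       {py::cast(memory_format)});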

// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj is a Tensor or not.
// - If it is a tensor, we need to be careful about PyObject resurrection
// - If it is not a tensor, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized())
    return;

  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
  // 1. We are decref-ing a tensor. Then we must be careful about
  //    PyObject resurrection (this only applies to Tensors, see
  //    THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  //    PyObject resurrection on non-Tensors, so we just carry on as usual
  if (is_tensor && Py_REFCNT(pyobj) > 1) {
    // It's still alive! This can happen if a weak ref resurrected
    // the PyObject without flipping ownership. At this point it is
    // too late to rescue the object, so just stub out the PyObject
    // so that it fails on subsequent uses. Don't raise an error here;
    // you're probably in a destructor.
    TORCH_WARN(
        "Deallocating Tensor that still has live PyObject references. "
        "This probably happened because you took out a weak reference to "
        "Tensor and didn't call _fix_weakref() after dereferencing it. "
        "Subsequent accesses to this tensor via the PyObject will now fail.");
    ((THPVariable*)pyobj)->cdata =
        c10::MaybeOwned<torch::autograd::Variable>();
  }
  Py_DECREF(pyobj);
}

py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    const char* func_name = qualified_name.c_str() + pos + strlen("::");

    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}
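
// Resolution example: an operator named "aten::sum" with overload name
// "dim_IntList" resolves to torch.ops.aten.sum.dim_IntList; an empty
// overload name resolves to torch.ops.aten.sum.default.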
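
// A tensor is a "Python tensor" if its TensorImpl carries the Python
// dispatch key, i.e. it is backed by a __torch_dispatch__ Tensor subclass.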
bool isPythonTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}

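// dispatch() re-enters Python when an operator call hits the Python dispatch
// key: every Tensor argument backed by a subclass is collected as an
// "overloaded" argument, and the op is forwarded to __torch_dispatch__.
// Roughly, for a subclass defined as (illustrative Python):
//
//   class MyTensor(torch.Tensor):
//       @classmethod
//       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//           ...
//
// a call like torch.add(my_tensor, 2) lands here with
// func == torch.ops.aten.add.Tensor.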
void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // The plan: convert all the arguments back into PyObjects,
  // extracting out the tensor handles, then call
  // handle_torch_function_no_python_arg_parser
  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present

  py::gil_scoped_acquire g;

  std::vector<py::handle> overloaded_args;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // Find overloaded tensors
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& tensor = ivalue.toTensor();
      if (isPythonTensor(tensor)) {
        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
      }
    } else if (ivalue.isList()) {
      const auto& list = ivalue.toListRef();
      for (const auto jdx : c10::irange(list.size())) {
        const auto& nv = list[jdx];
        if (nv.isTensor()) {
          const auto& tensor = nv.toTensor();
          if (isPythonTensor(tensor)) {
            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
          }
        }
      }
    }
  }

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  PyObject* obj = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      torch_api_function_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}

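// Protocol for the Python dispatcher cache: an OpOverload's _dispatch_cache
// maps a DispatchKey to a handler. If the cached handler is itself a
// DispatchKey, there is no Python-level handler and we call the C++ kernel
// registered for that key directly; otherwise the handler is a Python
// callable that receives the already-parsed args/kwargs.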
void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire g;
  py::handle torch_api_function_overload = getTorchApiFunction(op);
  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
  if (cache.ptr() == nullptr) {
    throw python_error();
  }

  c10::DispatchKey k = ks.highestPriorityTypeId();
  // TODO: allow this to be non-owning
  auto handler = py::reinterpret_borrow<py::object>(
      PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
  if (handler.ptr() == nullptr) {
    // Slow path
    handler = torch_api_function_overload.attr("_get_dispatch")(k);
  }
  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
    return;
  }

  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  py::object obj = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));

  if (obj.ptr() == nullptr) {
    throw python_error();
  }

  pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
}

c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "detach",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("detach")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      THPVariable_Check(out.ptr()),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Tensor");
  const at::Tensor& res_t = THPVariable_Unpack(out.ptr());
  return res_t.getIntrusivePtr();
}

bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  py::object out;
  if (memory_format == at::MemoryFormat::Contiguous) {
    // For backwards compatibility
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("default")
            .ptr(),
        "torch.ops.aten");
  } else {
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("memory_format")
            .ptr(),
        "torch.ops.aten",
        {py::cast(memory_format)});
  }

  if (out.is_none()) {
    return self->is_contiguous_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          // NB: intentionally suffixed with _format to avoid
          // triggering matches against "_like" suffix
          .attr("is_strides_like_format")
          .attr("default")
          .ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->is_strides_like_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_non_overlapping_and_dense",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("is_non_overlapping_and_dense")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "dim",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("dim")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      PyLong_Check(out.ptr()),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected int");

  return THPUtils_unpackLong(out.ptr());
}

c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "device",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("device")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  return toDevice(out.ptr());
}

c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }

  py::object values = std::move(out);

  c10::optional<PyObject*> mb_obj =
      self->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  PyObject* subclass = *mb_obj;
  Py_INCREF(subclass);
  py::object sub = py::reinterpret_steal<py::object>(subclass);

  py::object os = py::module_::import("torch").attr("overrides");
  py::function get_buffer =
      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
  auto buffer = get_buffer(sub, values, "stride");
  auto result = THPUtils_unpackLongs(buffer.ptr());
  int64_t* start = (int64_t*)result[0];
  int64_t len = result[1];

  return c10::IntArrayRef(start, len);
}

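// The buffer returned by torch.overrides.get_buffer encodes a (data pointer,
// length) pair as two int64_t values. The pointed-to storage is kept alive by
// the tensor subclass itself (see the TODO about making it owning in
// sym_sizes below), which is why the ArrayRefs returned from strides() above
// and sizes()/sym_sizes()/sym_strides() below are non-owning.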
static std::vector<int64_t> values_from_buffer(
    const c10::TensorImpl* self,
    py::handle values) {
  c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
  c10::optional<PyObject*> mb_obj =
      ptr->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");

  py::object os = py::module_::import("torch").attr("overrides");
  py::function get_buffer =
      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
  auto buffer = get_buffer(py::handle(*mb_obj), values, "size");
  auto result = THPUtils_unpackLongs(buffer.ptr());
  return result;
}

c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }

  py::object values = std::move(out);
  auto result = values_from_buffer(self, values);
  int64_t* start = (int64_t*)result[0];
  int64_t len = result[1];

  return c10::IntArrayRef(start, len);
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_sizes_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInt`s, since that is
  // the format `sym_sizes()` is stored in.
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "Symshape must be a list or a tuple");
  py::list symints;
  for (auto it = out.begin(); it != out.end(); it++) {
    auto elm = *it;
    auto si = py::cast<c10::SymInt>(elm);
    // TODO: the buffer will need to be made owning later
    symints.append(si.as_int_unchecked());
  }

  auto result = values_from_buffer(self, symints);
  c10::SymInt* start = (c10::SymInt*)result[0];
  int64_t len = result[1];

  return c10::SymIntArrayRef(start, len);
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "layout",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("layout")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  TORCH_CHECK(
      THPLayout_Check(out.ptr()),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Layout");

  return toLayout(out.ptr());
}

c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->sym_numel_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_storage_offset",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_storage_offset")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_strides_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInt`s, since that is
  // the format `sym_strides()` is stored in.
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "Symshape must be a list or a tuple");
  py::list symints;
  for (auto it = out.begin(); it != out.end(); it++) {
    auto elm = *it;
    auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
                                    : c10::SymInt{py::cast<int64_t>(elm)};
    symints.append(si.as_int_unchecked());
  }

  auto result = values_from_buffer(self, symints);
  c10::SymInt* start = (c10::SymInt*)result[0];
  int64_t len = result[1];

  return c10::SymIntArrayRef(start, len);
  END_HANDLE_TH_ERRORS_PYBIND
}

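// The process-wide interpreter binding for this Python interpreter.
// Constructed when this translation unit is loaded; at static destruction
// time the holder disarms (but deliberately leaks) the PyInterpreter, see
// PyInterpreterHolder above.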
PyInterpreterHolder self_interpreter;

void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t = Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
  END_HANDLE_TH_ERRORS_PYBIND
}

} // anonymous namespace

c10::impl::PyInterpreter* getPyInterpreter() {
  return self_interpreter.get();
}

bool isMainPyInterpreter() {
  return self_interpreter.is_main_interpreter();
}

std::string ConcretePyInterpreterVTable::name() const {
  std::stringstream ss;
  ss << getPyInterpreter();
  return ss.str();
}