#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <string>

using namespace torch;
using namespace at;
using namespace c10;

namespace {

// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
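// Fires the Python-side CUDA trace callbacks registered in
// torch.utils._cuda_trace (only if the Python interpreter is up), logging
// rather than propagating any exception raised by a callback.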
#define CONCRETE_TRACE_CUDA(func_name, ...)                           \
  at::impl::MaybeSetTLSOnEntryGuard guard;                            \
  if (Py_IsInitialized()) {                                           \
    pybind11::gil_scoped_acquire gil;                                 \
    try {                                                             \
      py::module mod = py::module::import("torch.utils._cuda_trace"); \
      py::object hook = mod.attr(func_name).attr("fire_callbacks");   \
      hook(__VA_ARGS__);                                              \
    } catch (const std::exception& e) {                               \
      LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
    }                                                                 \
  }

struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;

  void decref(PyObject* pyobj, bool is_tensor) const override;

  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
  c10::intrusive_ptr<c10::TensorImpl> detach(
      const c10::TensorImpl* self) const override;

  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override;
  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override;
  // NB: python_op_registration_trampoline_impl is defined in
  // python_dispatch.cpp
  void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey key,
      torch::jit::Stack* stack) const override {
    torch::impl::dispatch::python_op_registration_trampoline_impl(
        op, key, stack);
  }

  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
  c10::Device device(const c10::TensorImpl* self) const override;
  int64_t dim(const c10::TensorImpl* self) const override;
  c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
  c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
  c10::Layout layout(const c10::TensorImpl* self) const override;
  c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
  c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;

  void trace_gpu_event_creation(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event);
  }
  void trace_gpu_event_deletion(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event);
  }
  void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
      const override {
    CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream);
  }
  void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream);
  }
  void trace_gpu_memory_allocation(uintptr_t ptr) const override {
    CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr);
  }
  void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
    CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr);
  }
  void trace_gpu_stream_creation(uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream);
  }
  void trace_gpu_device_synchronization() const override {
    CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks");
  }
  void trace_gpu_stream_synchronization(uintptr_t stream) const override {
    CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream);
  }
  void trace_gpu_event_synchronization(uintptr_t event) const override {
    CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
  }

  void reset_backward_hooks(const c10::TensorImpl* self) const override;

  static ConcretePyInterpreterVTable* instance() {
    static ConcretePyInterpreterVTable s;
    return &s;
  }
};

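// Owns the c10::impl::PyInterpreter for this Python interpreter and registers
// it with PythonOpRegistrationTrampoline; the registration result records
// whether this is the main interpreter (see is_main_interpreter()).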
class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())) {
    is_main_interpreter_ =
        at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_);
  }
  // NB: intentionally leaks the PyInterpreter, as there may still be live
  // references to it held by objects that are not destructed while Python is
  // being cleaned up.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_;
};

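// Shared helper for the vtable entries below: wraps `self` in a (non-owning)
// at::Tensor, prepends it to `extra_args`, and routes the call through
// __torch_dispatch__ via handle_torch_function_no_python_arg_parser.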
py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");

  std::vector<py::handle> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());
  auto args =
      py::reinterpret_steal<py::object>(PyTuple_New(1 + extra_args.size()));
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  int64_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr)
      throw python_error();
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    i++;
  }

  py::dict kwargs;

  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}

// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj is a Tensor or not.
// - If it is a tensor, we need to be careful about PyObject resurrection
// - If it is not a tensor, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized())
    return;

  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
  // 1. We are decref-ing a tensor. Then we must be careful about
  //    PyObject resurrection (this only applies to Tensors, see
  //    THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  //    PyObject resurrection on non-Tensors, so we just carry on as usual
  if (is_tensor && Py_REFCNT(pyobj) > 1) {
    // It's still alive! This can happen if a weak ref resurrected
    // the PyObject without flipping ownership. At this point it is
    // too late to rescue the object, so just stub out the PyObject
    // so that it fails on subsequent uses. Don't raise an error here;
    // you're probably in a destructor.
    TORCH_WARN(
        "Deallocating Tensor that still has live PyObject references. "
        "This probably happened because you took out a weak reference to "
        "Tensor and didn't call _fix_weakref() after dereferencing it. "
        "Subsequent accesses to this tensor via the PyObject will now fail.");
    ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>();
  }
  Py_DECREF(pyobj);
}

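// Resolves the Python OpOverload object (torch.ops.<ns>.<name>.<overload or
// default>) for an operator; getPythonOp caches the result on the
// OperatorHandle so the import/attr lookups only happen once per interpreter.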
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    const char* func_name = qualified_name.c_str() + pos + strlen("::");

    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}

bool isPythonTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}

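// Handles an operator call that reached the Python dispatch key: pops the
// boxed arguments, gathers the Tensor arguments (including Tensors inside
// lists) that carry DispatchKey::Python, invokes __torch_dispatch__ on the
// resolved OpOverload, and pushes the Python result back onto the stack.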
void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // The plan: convert all the arguments back into PyObjects,
  // extracting out the tensor handles, then call
  // handle_torch_function_no_python_arg_parser
  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present

  py::gil_scoped_acquire g;

  std::vector<py::handle> overloaded_args;
  py::handle torch_api_function_overload = getTorchApiFunction(op);

  // Find overloaded tensors
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      const auto& tensor = ivalue.toTensor();
      if (isPythonTensor(tensor)) {
        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
      }
    } else if (ivalue.isList()) {
      const auto& list = ivalue.toListRef();
      for (const auto jdx : c10::irange(list.size())) {
        const auto& nv = list[jdx];
        if (nv.isTensor()) {
          const auto& tensor = nv.toTensor();
          if (isPythonTensor(tensor)) {
            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
          }
        }
      }
    }
  }

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  PyObject* obj = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      torch_api_function_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}

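// Python dispatcher hook: looks up the handler for the highest-priority
// dispatch key in the OpOverload's _dispatch_cache (falling back to
// _get_dispatch on a miss). If the cached handler is itself a DispatchKey, no
// Python call is needed and we run the boxed kernel for that key directly;
// otherwise the handler is called with the op's args/kwargs.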
void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire g;
  py::handle torch_api_function_overload = getTorchApiFunction(op);
  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache"));
  if (cache.ptr() == nullptr) {
    throw python_error();
  }

  c10::DispatchKey k = ks.highestPriorityTypeId();
  // TODO: allow this to be non-owning
  auto handler = py::reinterpret_borrow<py::object>(
      PyDict_GetItem(cache.ptr(), py::cast(k).ptr()));
  if (handler.ptr() == nullptr) {
    // Slow path
    handler = torch_api_function_overload.attr("_get_dispatch")(k);
  }
  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
    return;
  }

  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  auto args = std::move(args_kwargs.first);
  auto kwargs = std::move(args_kwargs.second);

  py::object obj = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));

  if (obj.ptr() == nullptr) {
    throw python_error();
  }

  pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
}

c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "detach",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("detach")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      THPVariable_Check(out.ptr()),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Tensor");
  const at::Tensor& res_t = THPVariable_Unpack(out.ptr());
  return res_t.getIntrusivePtr();
}

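// Most of the metadata queries below follow the same pattern: ask the Python
// subclass via the corresponding aten/prim op through __torch_dispatch__, and
// (where supported) treat a None return as "fall back to the default
// TensorImpl implementation".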
bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  py::object out;
  if (memory_format == at::MemoryFormat::Contiguous) {
    // For backwards compatibility
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("default")
            .ptr(),
        "torch.ops.aten");
  } else {
    out = torchDispatchFromTensorImpl(
        self,
        "is_contiguous",
        py::module::import("torch")
            .attr("ops")
            .attr("aten")
            .attr("is_contiguous")
            .attr("memory_format")
            .ptr(),
        "torch.ops.aten",
        {py::cast(memory_format)});
  }

  if (out.is_none()) {
    return self->is_contiguous_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          // NB: intentionally suffixed with _format to avoid
          // triggering matches against "_like" suffix
          .attr("is_strides_like_format")
          .attr("default")
          .ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});

  if (out.is_none()) {
    return self->is_strides_like_default(memory_format);
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "is_non_overlapping_and_dense",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("is_non_overlapping_and_dense")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }

  TORCH_CHECK(
      PyBool_Check(out.ptr()),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected bool");

  return PyObject_IsTrue(out.ptr());
}

int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "dim",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("dim")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  TORCH_CHECK(
      PyLong_Check(out.ptr()),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected int");

  return THPUtils_unpackLong(out.ptr());
}

c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "device",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("device")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  return toDevice(out.ptr());
}

c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }

  py::object values = py::reinterpret_steal<py::object>(out.ptr());

  c10::optional<PyObject*> mb_obj =
      self->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  PyObject* subclass = *mb_obj;
  Py_INCREF(subclass);
  py::object sub = py::reinterpret_steal<py::object>(subclass);

  py::object os = py::module_::import("torch").attr("overrides");
  py::function get_buffer =
      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
  auto buffer = get_buffer(sub, values, "stride");
  auto result = THPUtils_unpackLongs(buffer.ptr());
  int64_t* start = (int64_t*)result[0];
  int64_t len = result[1];

  return c10::IntArrayRef(start, len);
}

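// Asks torch.overrides.get_buffer for a buffer backing the given size/stride
// values; the result comes back as two int64_t slots which callers
// reinterpret as a data pointer and a length (see sizes()/sym_sizes()).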
static std::vector<int64_t> values_from_buffer(
    const c10::TensorImpl* self,
    py::handle values) {
  c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
  c10::optional<PyObject*> mb_obj =
      ptr->pyobj_slot()->check_pyobj(getPyInterpreter());
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");

  py::object os = py::module_::import("torch").attr("overrides");
  py::function get_buffer =
      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
  auto buffer = get_buffer(py::handle(*mb_obj), values, "size");
  auto result = THPUtils_unpackLongs(buffer.ptr());
  return result;
}

c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }

  py::object values = py::reinterpret_steal<py::object>(out.ptr());
  auto result = values_from_buffer(self, values);
  int64_t* start = (int64_t*)result[0];
  int64_t len = result[1];

  return c10::IntArrayRef(start, len);
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_size",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_size")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_sizes_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInt`s,
  // since that is the format `sym_sizes()` is stored in
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "Symshape must be a list or a tuple");
  py::list symints;
  for (auto it = out.begin(); it != out.end(); it++) {
    auto elm = *it;
    auto si = py::cast<c10::SymInt>(elm);
    // TODO: the buffer will need to be made owning later
    symints.append(si.as_int_unchecked());
  }

  auto result = values_from_buffer(self, symints);
  c10::SymInt* start = (c10::SymInt*)result[0];
  int64_t len = result[1];

  return c10::SymIntArrayRef(start, len);
  END_HANDLE_TH_ERRORS_PYBIND
}

c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "layout",
      py::module::import("torch")
          .attr("ops")
          .attr("prim")
          .attr("layout")
          .attr("default")
          .ptr(),
      "torch.ops.prim");

  TORCH_CHECK(
      THPLayout_Check(out.ptr()),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())),
      ", expected Layout");

  return toLayout(out.ptr());
}

c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_numel",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_numel")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call numel on a tensor with symbolic shapes/strides");
    return self->sym_numel_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_storage_offset",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_storage_offset")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_storage_offset_default();
  }
  return torch::is_symint(out) ? out.cast<c10::SymInt>()
                               : c10::SymInt{py::cast<int64_t>(out)};
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  auto out = torchDispatchFromTensorImpl(
      self,
      "sym_stride",
      py::module::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("sym_stride")
          .attr("default")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->sym_strides_default();
  }
  // We need to squeeze SymIntNodes and ints into `SymInt`s,
  // since that is the format `sym_strides()` is stored in
  TORCH_CHECK(
      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
      "Symshape must be a list or a tuple");
  py::list symints;
  for (auto it = out.begin(); it != out.end(); it++) {
    auto elm = *it;
    auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
                                    : c10::SymInt{py::cast<int64_t>(elm)};
    symints.append(si.as_int_unchecked());
  }

  auto result = values_from_buffer(self, symints);
  c10::SymInt* start = (c10::SymInt*)result[0];
  int64_t len = result[1];

  return c10::SymIntArrayRef(start, len);
  END_HANDLE_TH_ERRORS_PYBIND
}

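// Static holder for this module's PyInterpreter; by design it outlives Python
// shutdown (see the NB on ~PyInterpreterHolder above).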
PyInterpreterHolder self_interpreter;

void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t = Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p =
      py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t)));
  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
  END_HANDLE_TH_ERRORS_PYBIND
}

} // anonymous namespace

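// Public accessors: getPyInterpreter() hands out the interpreter owned by
// self_interpreter above, and isMainPyInterpreter() reports the result of
// registering it with PythonOpRegistrationTrampoline.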
c10::impl::PyInterpreter* getPyInterpreter() {
  return self_interpreter.get();
}

bool isMainPyInterpreter() {
  return self_interpreter.is_main_interpreter();
}

std::string ConcretePyInterpreterVTable::name() const {
  std::stringstream ss;
  ss << getPyInterpreter();
  return ss.str();
}