1#include <torch/csrc/python_headers.h>
2
3#include <torch/csrc/Device.h>
4#include <torch/csrc/Dtype.h>
5#include <torch/csrc/DynamicTypes.h>
6#include <torch/csrc/Exceptions.h>
7#include <torch/csrc/Layout.h>
8#include <torch/csrc/PythonTypes.h>
9#include <torch/csrc/Storage.h>
10#include <torch/csrc/autograd/generated/VariableType.h>
11#include <torch/csrc/utils/cuda_enabled.h>
12#include <torch/csrc/utils/cuda_lazy_init.h>
13#include <torch/csrc/utils/object_ptr.h>
14
15#include <ATen/ATen.h>
16#include <ATen/FunctionalStorageImpl.h>
17
18#include <array>
19#include <memory>
20#include <sstream>
21#include <stdexcept>
22#include <string>
23#include <unordered_map>
24#include <vector>
25
26namespace torch {
27namespace {
28std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
29 dtype_registry = {};
30
31std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
32 layout_registry = {};
33
34at::DeprecatedTypeProperties* get_type_properties(
35 at::DeviceType device_type,
36 at::ScalarType scalarType) {
37 at::Backend backend;
38 if (device_type == at::kCPU) {
39 backend = at::Backend::CPU;
40 } else if (device_type == at::kCUDA) {
41 backend = at::Backend::CUDA;
42 } else if (device_type == at::kXPU) {
43 backend = at::Backend::XPU;
44 } else if (device_type == at::kMPS) {
45 backend = at::Backend::MPS;
46 } else if (device_type == at::DeviceType::Meta) {
47 backend = at::Backend::Undefined;
48 } else {
49 TORCH_CHECK(false, "Invalid device for storage: ", device_type);
50 }
51 return &at::getDeprecatedTypeProperties(backend, scalarType);
52}
53} // namespace
54
55void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
56 dtype_registry[static_cast<int>(scalarType)] = dtype;
57}
58
59void registerLayoutObject(THPLayout* thp_layout, at::Layout layout) {
60 layout_registry[static_cast<int>(layout)] = thp_layout;
61}
62
63THPDtype* getTHPDtype(at::ScalarType scalarType) {
64 auto dtype = dtype_registry[static_cast<int>(scalarType)];
65 if (!dtype) {
66 throw std::invalid_argument("unsupported scalarType");
67 }
68 return dtype;
69}
70
71THPLayout* getTHPLayout(at::Layout layout) {
72 auto thp_layout = layout_registry[static_cast<int>(layout)];
73 if (!thp_layout) {
74 throw std::invalid_argument("unsupported at::Layout");
75 }
76 return thp_layout;
77}
78
79PyObject* createPyObject(const at::Storage& storage) {
80 if (storage.device_type() != at::DeviceType::Meta &&
81 storage.data() == nullptr && storage.sym_nbytes() != 0 &&
82 // Grabbing storage() from FunctionalTensorWrapper is allowed.
83 // This is useful for checking aliasing info from python
84 dynamic_cast<at::functionalization::FunctionalStorageImpl*>(
85 storage.unsafeGetStorageImpl()) == nullptr) {
86 TORCH_CHECK_NOT_IMPLEMENTED(
87 false,
88 "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
89 }
90 PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
91 auto obj = THPObjectPtr(type->tp_alloc(type, 0));
92 if (!obj)
93 throw python_error();
94 ((THPVoidStorage*)obj.get())->cdata =
95 at::Storage(/* copy */ storage).unsafeReleaseStorageImpl();
96 return obj.release();
97}
98
99PyTypeObject* loadTypedStorageTypeObject() {
100 PyObject* storage_module = PyImport_ImportModule("torch.storage");
101 TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));
102
103 PyObject* typed_storage_obj =
104 PyObject_GetAttrString(storage_module, "TypedStorage");
105 TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
106 return reinterpret_cast<PyTypeObject*>(
107 PyObject_GetAttrString(storage_module, "TypedStorage"));
108}
109
110PyTypeObject* getTypedStorageTypeObject() {
111 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
112 static PyTypeObject* typed_storage_type_obj = loadTypedStorageTypeObject();
113 return typed_storage_type_obj;
114}
115
116bool isStorage(PyObject* obj) {
117 if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
118 return true;
119 }
120 auto obj_type = Py_TYPE(obj);
121
122 return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
123}
124
125at::Storage createStorageGetType(
126 PyObject* obj,
127 at::ScalarType& scalar_type,
128 bool& is_typed_storage) {
129 is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
130 PyObject* untyped_storage_obj;
131
132 if (is_typed_storage) {
133 // NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
134 // `_untyped_storage`, so we must decrement them. The refcounts will still
135 // stay nonzero since the `TypedStorage` maintains a reference.
136 PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
137 TORCH_INTERNAL_ASSERT(dtype_obj);
138 Py_DECREF(dtype_obj);
139
140 TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
141 scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
142
143 untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
144 TORCH_INTERNAL_ASSERT(untyped_storage_obj);
145 Py_DECREF(untyped_storage_obj);
146
147 } else {
148 scalar_type = at::kByte;
149 untyped_storage_obj = obj;
150 }
151
152 if (Py_TYPE(untyped_storage_obj) !=
153 reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
154 throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
155 }
156
157 c10::StorageImpl* impl = static_cast<c10::StorageImpl*>(
158 ((THPVoidStorage*)untyped_storage_obj)->cdata);
159 c10::DeviceType device_type = impl->device().type();
160
161 auto type_properties = get_type_properties(device_type, at::kByte);
162
163 return type_properties->unsafeStorageFromTH(
164 ((THPVoidStorage*)untyped_storage_obj)->cdata, true);
165}
166
167at::Storage createStorage(PyObject* obj) {
168 at::ScalarType scalar_type;
169 bool is_typed_storage = false;
170 return createStorageGetType(obj, scalar_type, is_typed_storage);
171}
172
173} // namespace torch
174