1 | #include <torch/csrc/python_headers.h> |
2 | |
3 | #include <torch/csrc/Device.h> |
4 | #include <torch/csrc/Dtype.h> |
5 | #include <torch/csrc/DynamicTypes.h> |
6 | #include <torch/csrc/Exceptions.h> |
7 | #include <torch/csrc/Layout.h> |
8 | #include <torch/csrc/PythonTypes.h> |
9 | #include <torch/csrc/Storage.h> |
10 | #include <torch/csrc/autograd/generated/VariableType.h> |
11 | #include <torch/csrc/utils/cuda_enabled.h> |
12 | #include <torch/csrc/utils/cuda_lazy_init.h> |
13 | #include <torch/csrc/utils/object_ptr.h> |
14 | |
15 | #include <ATen/ATen.h> |
16 | #include <ATen/FunctionalStorageImpl.h> |
17 | |
18 | #include <array> |
19 | #include <memory> |
20 | #include <sstream> |
21 | #include <stdexcept> |
22 | #include <string> |
23 | #include <unordered_map> |
24 | #include <vector> |
25 | |
26 | namespace torch { |
27 | namespace { |
// Lookup tables from ATen enum values to their Python wrapper objects
// (torch.dtype / torch.layout). Indexed by the enum's integer value.
// Filled in via registerDtypeObject / registerLayoutObject (callers are
// outside this file — presumably module init; entries are expected to
// outlive all lookups). Unregistered slots stay nullptr.
std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
    dtype_registry = {};

std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
    layout_registry = {};
33 | |
34 | at::DeprecatedTypeProperties* get_type_properties( |
35 | at::DeviceType device_type, |
36 | at::ScalarType scalarType) { |
37 | at::Backend backend; |
38 | if (device_type == at::kCPU) { |
39 | backend = at::Backend::CPU; |
40 | } else if (device_type == at::kCUDA) { |
41 | backend = at::Backend::CUDA; |
42 | } else if (device_type == at::kXPU) { |
43 | backend = at::Backend::XPU; |
44 | } else if (device_type == at::kMPS) { |
45 | backend = at::Backend::MPS; |
46 | } else if (device_type == at::DeviceType::Meta) { |
47 | backend = at::Backend::Undefined; |
48 | } else { |
49 | TORCH_CHECK(false, "Invalid device for storage: " , device_type); |
50 | } |
51 | return &at::getDeprecatedTypeProperties(backend, scalarType); |
52 | } |
53 | } // namespace |
54 | |
55 | void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) { |
56 | dtype_registry[static_cast<int>(scalarType)] = dtype; |
57 | } |
58 | |
59 | void registerLayoutObject(THPLayout* thp_layout, at::Layout layout) { |
60 | layout_registry[static_cast<int>(layout)] = thp_layout; |
61 | } |
62 | |
63 | THPDtype* getTHPDtype(at::ScalarType scalarType) { |
64 | auto dtype = dtype_registry[static_cast<int>(scalarType)]; |
65 | if (!dtype) { |
66 | throw std::invalid_argument("unsupported scalarType" ); |
67 | } |
68 | return dtype; |
69 | } |
70 | |
71 | THPLayout* getTHPLayout(at::Layout layout) { |
72 | auto thp_layout = layout_registry[static_cast<int>(layout)]; |
73 | if (!thp_layout) { |
74 | throw std::invalid_argument("unsupported at::Layout" ); |
75 | } |
76 | return thp_layout; |
77 | } |
78 | |
79 | PyObject* createPyObject(const at::Storage& storage) { |
80 | if (storage.device_type() != at::DeviceType::Meta && |
81 | storage.data() == nullptr && storage.sym_nbytes() != 0 && |
82 | // Grabbing storage() from FunctionalTensorWrapper is allowed. |
83 | // This is useful for checking aliasing info from python |
84 | dynamic_cast<at::functionalization::FunctionalStorageImpl*>( |
85 | storage.unsafeGetStorageImpl()) == nullptr) { |
86 | TORCH_CHECK_NOT_IMPLEMENTED( |
87 | false, |
88 | "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details" ); |
89 | } |
90 | PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass); |
91 | auto obj = THPObjectPtr(type->tp_alloc(type, 0)); |
92 | if (!obj) |
93 | throw python_error(); |
94 | ((THPVoidStorage*)obj.get())->cdata = |
95 | at::Storage(/* copy */ storage).unsafeReleaseStorageImpl(); |
96 | return obj.release(); |
97 | } |
98 | |
99 | PyTypeObject* loadTypedStorageTypeObject() { |
100 | PyObject* storage_module = PyImport_ImportModule("torch.storage" ); |
101 | TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module)); |
102 | |
103 | PyObject* typed_storage_obj = |
104 | PyObject_GetAttrString(storage_module, "TypedStorage" ); |
105 | TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj)); |
106 | return reinterpret_cast<PyTypeObject*>( |
107 | PyObject_GetAttrString(storage_module, "TypedStorage" )); |
108 | } |
109 | |
110 | PyTypeObject* getTypedStorageTypeObject() { |
111 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
112 | static PyTypeObject* typed_storage_type_obj = loadTypedStorageTypeObject(); |
113 | return typed_storage_type_obj; |
114 | } |
115 | |
116 | bool isStorage(PyObject* obj) { |
117 | if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) { |
118 | return true; |
119 | } |
120 | auto obj_type = Py_TYPE(obj); |
121 | |
122 | return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass); |
123 | } |
124 | |
// Convert a Python storage object (TypedStorage or UntypedStorage) into
// an at::Storage.
//
// Out-parameters:
//   scalar_type      - the TypedStorage's dtype, or kByte for untyped input.
//   is_typed_storage - whether `obj` was a TypedStorage.
//
// Throws TypeError if the (unwrapped) object is not an UntypedStorage.
at::Storage createStorageGetType(
    PyObject* obj,
    at::ScalarType& scalar_type,
    bool& is_typed_storage) {
  is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
  PyObject* untyped_storage_obj;

  if (is_typed_storage) {
    // NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
    // `_untyped_storage`, so we must decrement them. The refcounts will still
    // stay nonzero since the `TypedStorage` maintains a reference.
    // (i.e. the early Py_DECREFs below leave borrowed-but-valid pointers;
    // this relies on `obj` staying alive for the rest of this function.)
    PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
    TORCH_INTERNAL_ASSERT(dtype_obj);
    Py_DECREF(dtype_obj);

    TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
    scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;

    untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
    TORCH_INTERNAL_ASSERT(untyped_storage_obj);
    Py_DECREF(untyped_storage_obj);

  } else {
    // Untyped storages are always byte-addressed.
    scalar_type = at::kByte;
    untyped_storage_obj = obj;
  }

  if (Py_TYPE(untyped_storage_obj) !=
      reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
    throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
  }

  c10::StorageImpl* impl = static_cast<c10::StorageImpl*>(
      ((THPVoidStorage*)untyped_storage_obj)->cdata);
  c10::DeviceType device_type = impl->device().type();

  // The underlying storage is always treated as bytes here, regardless of
  // the dtype reported via `scalar_type`.
  auto type_properties = get_type_properties(device_type, at::kByte);

  // `true` retains the impl, so the returned Storage holds its own
  // reference independent of the Python object.
  return type_properties->unsafeStorageFromTH(
      ((THPVoidStorage*)untyped_storage_obj)->cdata, true);
}
166 | |
167 | at::Storage createStorage(PyObject* obj) { |
168 | at::ScalarType scalar_type; |
169 | bool is_typed_storage = false; |
170 | return createStorageGetType(obj, scalar_type, is_typed_storage); |
171 | } |
172 | |
173 | } // namespace torch |
174 | |