#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>

#include <iostream>
#include <memory>
#include <sstream>
6 | |
7 | using namespace std; |
8 | namespace at { |
9 | |
10 | DLDataType getDLDataType(const Tensor& t) { |
11 | DLDataType dtype; |
12 | dtype.lanes = 1; |
13 | dtype.bits = t.element_size() * 8; |
14 | switch (t.scalar_type()) { |
15 | case ScalarType::Byte: |
16 | dtype.code = DLDataTypeCode::kDLUInt; |
17 | break; |
18 | case ScalarType::Char: |
19 | dtype.code = DLDataTypeCode::kDLInt; |
20 | break; |
21 | // NOLINTNEXTLINE(bugprone-branch-clone) |
22 | case ScalarType::Double: |
23 | dtype.code = DLDataTypeCode::kDLFloat; |
24 | break; |
25 | case ScalarType::Float: |
26 | dtype.code = DLDataTypeCode::kDLFloat; |
27 | break; |
28 | // NOLINTNEXTLINE(bugprone-branch-clone) |
29 | case ScalarType::Int: |
30 | dtype.code = DLDataTypeCode::kDLInt; |
31 | break; |
32 | case ScalarType::Long: |
33 | dtype.code = DLDataTypeCode::kDLInt; |
34 | break; |
35 | case ScalarType::Short: |
36 | dtype.code = DLDataTypeCode::kDLInt; |
37 | break; |
38 | case ScalarType::Half: |
39 | dtype.code = DLDataTypeCode::kDLFloat; |
40 | break; |
41 | case ScalarType::Bool: |
42 | TORCH_CHECK(false, "Bool type is not supported by dlpack" ); |
43 | break; |
44 | case ScalarType::ComplexHalf: |
45 | dtype.code = DLDataTypeCode::kDLComplex; |
46 | break; |
47 | case ScalarType::ComplexFloat: |
48 | dtype.code = DLDataTypeCode::kDLComplex; |
49 | break; |
50 | case ScalarType::ComplexDouble: |
51 | dtype.code = DLDataTypeCode::kDLComplex; |
52 | break; |
53 | case ScalarType::BFloat16: |
54 | dtype.code = DLDataTypeCode::kDLBfloat; |
55 | break; |
56 | case ScalarType::QInt8: |
57 | case ScalarType::QUInt8: |
58 | case ScalarType::QInt32: |
59 | case ScalarType::QUInt4x2: |
60 | case ScalarType::QUInt2x4: |
61 | TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack" ); |
62 | break; |
63 | case ScalarType::Undefined: |
64 | TORCH_CHECK(false, "Undefined is not a valid ScalarType" ); |
65 | case ScalarType::NumOptions: |
66 | TORCH_CHECK(false, "NumOptions is not a valid ScalarType" ); |
67 | } |
68 | return dtype; |
69 | } |
70 | |
71 | DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) { |
72 | DLDevice ctx; |
73 | ctx.device_id = device_id; |
74 | switch (tensor.device().type()) { |
75 | case DeviceType::CPU: |
76 | ctx.device_type = DLDeviceType::kDLCPU; |
77 | break; |
78 | case DeviceType::CUDA: |
79 | #ifdef USE_ROCM |
80 | // ROCM, if enabled will look like cuda to PyTorch |
81 | // while everyone else should see HIP |
82 | ctx.device_type = DLDeviceType::kDLROCM; |
83 | #else |
84 | ctx.device_type = DLDeviceType::kDLCUDA; |
85 | #endif |
86 | break; |
87 | case DeviceType::OPENCL: |
88 | ctx.device_type = DLDeviceType::kDLOpenCL; |
89 | break; |
90 | case DeviceType::HIP: |
91 | ctx.device_type = DLDeviceType::kDLROCM; |
92 | break; |
93 | default: |
94 | TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); |
95 | } |
96 | return ctx; |
97 | } |
98 | |
99 | static Device getATenDevice(const DLDevice& ctx) { |
100 | switch (ctx.device_type) { |
101 | case DLDeviceType::kDLCPU: |
102 | return at::Device(DeviceType::CPU); |
103 | #ifndef USE_ROCM |
104 | // if we are compiled under HIP, we cannot do cuda |
105 | case DLDeviceType::kDLCUDA: |
106 | return at::Device(DeviceType::CUDA, ctx.device_id); |
107 | #endif |
108 | case DLDeviceType::kDLOpenCL: |
109 | return at::Device(DeviceType::OPENCL, ctx.device_id); |
110 | case DLDeviceType::kDLROCM: |
111 | #ifdef USE_ROCM |
112 | // this looks funny, we need to return CUDA here to masquerade |
113 | return at::Device(DeviceType::CUDA, ctx.device_id); |
114 | #else |
115 | return at::Device(DeviceType::HIP, ctx.device_id); |
116 | #endif |
117 | default: |
118 | TORCH_CHECK( |
119 | false, "Unsupported device_type: " + c10::to_string(ctx.device_type)); |
120 | } |
121 | } |
122 | |
123 | ScalarType toScalarType(const DLDataType& dtype) { |
124 | ScalarType stype; |
125 | TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1" ); |
126 | switch (dtype.code) { |
127 | case DLDataTypeCode::kDLUInt: |
128 | switch (dtype.bits) { |
129 | case 8: |
130 | stype = ScalarType::Byte; |
131 | break; |
132 | default: |
133 | TORCH_CHECK( |
134 | false, "Unsupported kUInt bits " + c10::to_string(dtype.bits)); |
135 | } |
136 | break; |
137 | case DLDataTypeCode::kDLInt: |
138 | switch (dtype.bits) { |
139 | case 8: |
140 | stype = ScalarType::Char; |
141 | break; |
142 | case 16: |
143 | stype = ScalarType::Short; |
144 | break; |
145 | case 32: |
146 | stype = ScalarType::Int; |
147 | break; |
148 | case 64: |
149 | stype = ScalarType::Long; |
150 | break; |
151 | default: |
152 | TORCH_CHECK( |
153 | false, "Unsupported kInt bits " + c10::to_string(dtype.bits)); |
154 | } |
155 | break; |
156 | case DLDataTypeCode::kDLFloat: |
157 | switch (dtype.bits) { |
158 | case 16: |
159 | stype = ScalarType::Half; |
160 | break; |
161 | case 32: |
162 | stype = ScalarType::Float; |
163 | break; |
164 | case 64: |
165 | stype = ScalarType::Double; |
166 | break; |
167 | default: |
168 | TORCH_CHECK( |
169 | false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); |
170 | } |
171 | break; |
172 | case DLDataTypeCode::kDLBfloat: |
173 | switch (dtype.bits) { |
174 | case 16: |
175 | stype = ScalarType::BFloat16; |
176 | break; |
177 | default: |
178 | TORCH_CHECK( |
179 | false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); |
180 | } |
181 | break; |
182 | case DLDataTypeCode::kDLComplex: |
183 | switch (dtype.bits) { |
184 | case 32: |
185 | stype = ScalarType::ComplexHalf; |
186 | break; |
187 | case 64: |
188 | stype = ScalarType::ComplexFloat; |
189 | break; |
190 | case 128: |
191 | stype = ScalarType::ComplexDouble; |
192 | break; |
193 | default: |
194 | TORCH_CHECK( |
195 | false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); |
196 | } |
197 | break; |
198 | default: |
199 | TORCH_CHECK( |
200 | false, "Unsupported code " + c10::to_string(dtype.code)); |
201 | } |
202 | return stype; |
203 | } |
204 | |
205 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
206 | struct ATenDLMTensor { |
207 | Tensor handle; |
208 | DLManagedTensor tensor; |
209 | }; |
210 | |
// DLManagedTensor deleter installed by toDLPack: reclaims the ATenDLMTensor
// stashed in manager_ctx, which drops the Tensor reference held in `handle`
// and thereby releases the underlying storage.
void deleter(DLManagedTensor* arg) {
  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}
214 | |
215 | // This function returns a shared_ptr to memory managed DLpack tensor |
216 | // constructed out of ATen tensor |
217 | DLManagedTensor* toDLPack(const Tensor& src) { |
218 | // create a new tensor with possibly normalized strides |
219 | // gh-83069 |
220 | auto shape = src.sizes(); |
221 | auto strides = src.strides().vec(); |
222 | for (int i=0; i<src.dim(); i++) { |
223 | if (shape[i] < 2) { |
224 | strides[i] = 1; |
225 | } |
226 | } |
227 | |
228 | auto view = src.as_strided(shape, strides, src.storage_offset()); |
229 | ATenDLMTensor* atDLMTensor(new ATenDLMTensor); |
230 | atDLMTensor->handle = view; |
231 | atDLMTensor->tensor.manager_ctx = atDLMTensor; |
232 | atDLMTensor->tensor.deleter = &deleter; |
233 | atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); |
234 | int64_t device_id = 0; |
235 | if (src.is_cuda()) { |
236 | device_id = src.get_device(); |
237 | } |
238 | atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id); |
239 | atDLMTensor->tensor.dl_tensor.ndim = src.dim(); |
240 | atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); |
241 | atDLMTensor->tensor.dl_tensor.shape = |
242 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
243 | const_cast<int64_t*>(view.sizes().data()); |
244 | atDLMTensor->tensor.dl_tensor.strides = |
245 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
246 | const_cast<int64_t*>(view.strides().data()); |
247 | atDLMTensor->tensor.dl_tensor.byte_offset = 0; |
248 | return &(atDLMTensor->tensor); |
249 | } |
250 | |
251 | Tensor fromDLPack(const DLManagedTensor* src) { |
252 | auto deleter = [src](void* self) { |
253 | if (src->deleter) { |
254 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
255 | src->deleter(const_cast<DLManagedTensor*>(src)); |
256 | } |
257 | }; |
258 | return fromDLPack(src, std::move(deleter)); |
259 | } |
260 | |
261 | Tensor fromDLPack( |
262 | const DLManagedTensor* src, |
263 | std::function<void(void*)> deleter) { |
264 | Device device = getATenDevice(src->dl_tensor.device); |
265 | ScalarType stype = toScalarType(src->dl_tensor.dtype); |
266 | if (!src->dl_tensor.strides) { |
267 | return at::from_blob(src->dl_tensor.data, |
268 | IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), |
269 | deleter, |
270 | at::device(device).dtype(stype)); |
271 | } |
272 | return at::from_blob( |
273 | src->dl_tensor.data, |
274 | IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), |
275 | IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim), |
276 | deleter, |
277 | at::device(device).dtype(stype), |
278 | { device }); |
279 | } |
280 | } // namespace at |
281 | |