#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>

#include <iostream>
#include <sstream>

using namespace std;
namespace at {

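// Maps an ATen scalar type onto the corresponding DLPack data type.
// Throws for types DLPack cannot represent (Bool and the quantized types).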
DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;
  switch (t.scalar_type()) {
    case ScalarType::Byte:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case ScalarType::Char:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Float:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Int:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Long:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Short:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Half:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Bool:
      TORCH_CHECK(false, "Bool type is not supported by dlpack");
      break;
    case ScalarType::ComplexHalf:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::ComplexFloat:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::ComplexDouble:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::BFloat16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
      break;
    case ScalarType::Undefined:
      TORCH_CHECK(false, "Undefined is not a valid ScalarType");
      break;
    case ScalarType::NumOptions:
      TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
  }
  return dtype;
}

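// Builds the DLPack device descriptor for a tensor. Note that a ROCm build
// reports kDLROCM even though the tensor's device type is CUDA internally.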
DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) {
  DLDevice ctx;
  ctx.device_id = device_id;
  switch (tensor.device().type()) {
    case DeviceType::CPU:
      ctx.device_type = DLDeviceType::kDLCPU;
      break;
    case DeviceType::CUDA:
#ifdef USE_ROCM
      // ROCm, if enabled, looks like CUDA to PyTorch,
      // but DLPack consumers should see it as ROCm
      ctx.device_type = DLDeviceType::kDLROCM;
#else
      ctx.device_type = DLDeviceType::kDLCUDA;
#endif
      break;
    case DeviceType::OPENCL:
      ctx.device_type = DLDeviceType::kDLOpenCL;
      break;
    case DeviceType::HIP:
      ctx.device_type = DLDeviceType::kDLROCM;
      break;
    default:
      TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
  }
  return ctx;
}

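// Inverse of getDLDevice: converts a DLPack device descriptor back to an
// ATen Device, applying the same CUDA/ROCm masquerading in reverse.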
static Device getATenDevice(const DLDevice& ctx) {
  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return at::Device(DeviceType::CPU);
#ifndef USE_ROCM
    // a ROCm build cannot handle CUDA devices
    case DLDeviceType::kDLCUDA:
      return at::Device(DeviceType::CUDA, ctx.device_id);
#endif
    case DLDeviceType::kDLOpenCL:
      return at::Device(DeviceType::OPENCL, ctx.device_id);
    case DLDeviceType::kDLROCM:
#ifdef USE_ROCM
      // this looks funny: a ROCm build must return CUDA here to keep up
      // the masquerade that getDLDevice established
      return at::Device(DeviceType::CUDA, ctx.device_id);
#else
      return at::Device(DeviceType::HIP, ctx.device_id);
#endif
    default:
      TORCH_CHECK(
          false, "Unsupported device_type: " + c10::to_string(ctx.device_type));
  }
}

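// Inverse of getDLDataType: maps a DLPack data type back to an ATen
// ScalarType. DLPack counts complex bits across both components, so a
// 32-bit kDLComplex maps to ComplexHalf (two 16-bit halves), and so on.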
ScalarType toScalarType(const DLDataType& dtype) {
  ScalarType stype;
  TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Byte;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Char;
          break;
        case 16:
          stype = ScalarType::Short;
          break;
        case 32:
          stype = ScalarType::Int;
          break;
        case 64:
          stype = ScalarType::Long;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::Half;
          break;
        case 32:
          stype = ScalarType::Float;
          break;
        case 64:
          stype = ScalarType::Double;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::BFloat16;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kBfloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 32:
          stype = ScalarType::ComplexHalf;
          break;
        case 64:
          stype = ScalarType::ComplexFloat;
          break;
        case 128:
          stype = ScalarType::ComplexDouble;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kComplex bits " + c10::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK(false, "Unsupported code " + c10::to_string(dtype.code));
  }
  return stype;
}

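// Glue that keeps the source ATen tensor alive for as long as the exported
// DLManagedTensor is in use: `handle` holds a reference to the tensor, and
// the manager_ctx deleter below frees this struct, releasing that reference.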
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ATenDLMTensor {
  Tensor handle;
  DLManagedTensor tensor;
};

void deleter(DLManagedTensor* arg) {
  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}

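// Example round trip (illustrative sketch; `t` is any ATen tensor):
//
//   DLManagedTensor* dlmt = at::toDLPack(t); // shares t's storage, no copy
//   at::Tensor t2 = at::fromDLPack(dlmt);    // t2 takes ownership of dlmt
//   // dlmt->deleter runs once t2's storage is released.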
// This function returns a DLManagedTensor, constructed from an ATen tensor,
// that holds a reference to the tensor's storage. The caller must either
// consume it (e.g. via fromDLPack) or invoke its deleter to avoid a leak.
DLManagedTensor* toDLPack(const Tensor& src) {
  // create a view with normalized strides: PyTorch allows arbitrary stride
  // values on dimensions of size 0 or 1, which some DLPack consumers
  // misinterpret, so force those strides to 1 (gh-83069)
  auto shape = src.sizes();
  auto strides = src.strides().vec();
  for (int i = 0; i < src.dim(); i++) {
    if (shape[i] < 2) {
      strides[i] = 1;
    }
  }

  auto view = src.as_strided(shape, strides, src.storage_offset());
  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  int64_t device_id = 0;
  if (src.is_cuda()) {
    device_id = src.get_device();
  }
  atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
  atDLMTensor->tensor.dl_tensor.ndim = src.dim();
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  return &(atDLMTensor->tensor);
}

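// Wraps a DLManagedTensor in an ATen tensor without copying. The returned
// tensor takes ownership: when its storage is released, it invokes
// src->deleter (if the producer provided one).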
Tensor fromDLPack(const DLManagedTensor* src) {
  auto deleter = [src](void* self) {
    if (src->deleter) {
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      src->deleter(const_cast<DLManagedTensor*>(src));
    }
  };
  return fromDLPack(src, std::move(deleter));
}

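// Overload taking a custom deleter that runs when the resulting tensor's
// storage is released. A null `strides` pointer means the data is compact
// and row-major per the DLPack spec, so the stride-less from_blob is used.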
Tensor fromDLPack(
    const DLManagedTensor* src,
    std::function<void(void*)> deleter) {
  Device device = getATenDevice(src->dl_tensor.device);
  ScalarType stype = toScalarType(src->dl_tensor.dtype);
  if (!src->dl_tensor.strides) {
    return at::from_blob(
        src->dl_tensor.data,
        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
        deleter,
        at::device(device).dtype(stype));
  }
  return at::from_blob(
      src->dl_tensor.data,
      IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
      deleter,
      at::device(device).dtype(stype),
      {device});
}
} // namespace at