1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/Tensor.h> |
5 | #include <ATen/dlpack.h> |
6 | |
// This convertor will:
// 1) take a Tensor object and wrap it in a DLPack tensor
// 2) take a DLPack tensor and convert it to an ATen Tensor
10 | |
11 | namespace at { |
12 | |
13 | TORCH_API ScalarType toScalarType(const DLDataType& dtype); |
14 | TORCH_API DLManagedTensor* toDLPack(const Tensor& src); |
15 | TORCH_API Tensor fromDLPack(const DLManagedTensor* src); |
16 | TORCH_API Tensor |
17 | fromDLPack(const DLManagedTensor* src, std::function<void(void*)> deleter); |
18 | TORCH_API DLDataType getDLDataType(const Tensor& t); |
19 | TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); |
20 | |
21 | } // namespace at |
22 |