1 | #include <c10/core/Scalar.h> |
---|---|
2 | #include <ATen/core/TensorBody.h> |
3 | |
4 | namespace at { |
5 | |
6 | #define DEFINE_CAST(T, name) \ |
7 | template <> \ |
8 | TORCH_API T* TensorBase::data_ptr() const { \ |
9 | TORCH_CHECK( \ |
10 | scalar_type() == ScalarType::name \ |
11 | || (isQIntType(scalar_type()) \ |
12 | && toUnderlying(scalar_type()) == ScalarType::name), \ |
13 | "expected scalar type " \ |
14 | #name \ |
15 | " but found ", \ |
16 | scalar_type()); \ |
17 | return this->unsafeGetTensorImpl()->data_ptr_impl<T>(); \ |
18 | } |
19 | |
20 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) |
21 | AT_FORALL_QINT_TYPES(DEFINE_CAST) |
22 | #undef DEFINE_CAST |
23 | |
24 | #define DEFINE_ITEM(T, name) \ |
25 | template <> \ |
26 | TORCH_API T Tensor::item() const { \ |
27 | return item().to##name(); \ |
28 | } |
29 | |
30 | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM) |
31 | #undef DEFINE_ITEM |
32 | |
33 | } //namespace at |
34 |