1#include <c10/core/Scalar.h>
2#include <ATen/core/TensorBody.h>
3
4namespace at {
5
6#define DEFINE_CAST(T, name) \
7 template <> \
8 TORCH_API T* TensorBase::data_ptr() const { \
9 TORCH_CHECK( \
10 scalar_type() == ScalarType::name \
11 || (isQIntType(scalar_type()) \
12 && toUnderlying(scalar_type()) == ScalarType::name), \
13 "expected scalar type " \
14 #name \
15 " but found ", \
16 scalar_type()); \
17 return this->unsafeGetTensorImpl()->data_ptr_impl<T>(); \
18 }
19
20 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
21 AT_FORALL_QINT_TYPES(DEFINE_CAST)
22 #undef DEFINE_CAST
23
24 #define DEFINE_ITEM(T, name) \
25 template <> \
26 TORCH_API T Tensor::item() const { \
27 return item().to##name(); \
28 }
29
30 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM)
31 #undef DEFINE_ITEM
32
33 } //namespace at
34