1#pragma once
2
3#include <c10/core/Backend.h>
4#include <c10/core/ScalarType.h>
5#include <c10/core/Layout.h>
6#include <c10/core/TensorOptions.h>
7#include <c10/core/Storage.h>
8#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
9#include <ATen/core/Generator.h>
10
11
12namespace at {
13
14class Tensor;
15
16// This class specifies a Backend and a ScalarType. Currently, it primarily
17// serves as a replacement return value for Tensor::type(). Previously,
18// Tensor::type() returned Type&, but we are changing Type to not be
19// dtype-specific.
20class TORCH_API DeprecatedTypeProperties {
21 public:
22 DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
23 : backend_(backend), scalar_type_(scalar_type) {}
24
25 Backend backend() const {
26 return backend_;
27 }
28
29 Layout layout() const {
30 return layout_from_backend(backend_);
31 }
32
33 bool is_sparse() const {
34 return layout_from_backend(backend()) == kSparse;
35 }
36
37 bool is_sparse_csr() const {
38 return layout_from_backend(backend()) == kSparseCsr;
39 }
40
41 DeviceType device_type() const {
42 return backendToDeviceType(backend_);
43 }
44
45 bool is_cuda() const {
46 return backendToDeviceType(backend_) == kCUDA;
47 }
48
49 ScalarType scalarType() const {
50 return scalar_type_;
51 }
52
53 caffe2::TypeMeta typeMeta() const {
54 return scalarTypeToTypeMeta(scalar_type_);
55 }
56
57 bool operator==(const DeprecatedTypeProperties& other) const {
58 return backend_ == other.backend() && scalar_type_ == other.scalarType();
59 }
60
61 bool operator!=(const DeprecatedTypeProperties& other) const {
62 return !(*this == other);
63 }
64
65 std::string toString() const {
66 std::string base_str;
67 if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) {
68 base_str = "UndefinedType";
69 } else {
70 base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type";
71 }
72 return base_str;
73 }
74
75 DeprecatedTypeProperties & toBackend(Backend b) const {
76 return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
77 b, scalar_type_);
78 }
79
80 DeprecatedTypeProperties & toScalarType(ScalarType s) const {
81 return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
82 backend_, s);
83 }
84
85 DeprecatedTypeProperties & cpu() const {
86 return toBackend(Backend::CPU);
87 }
88
89 DeprecatedTypeProperties & cuda() const {
90 return toBackend(Backend::CUDA);
91 }
92
93 DeprecatedTypeProperties & hip() const {
94 return toBackend(Backend::HIP);
95 }
96
97 /// Constructs the `TensorOptions` from a type and a `device_index`.
98 TensorOptions options(int16_t device_index = -1) const {
99 return TensorOptions().dtype(typeMeta())
100 .device(device_type(), static_cast<c10::DeviceIndex>(device_index))
101 .layout(layout());
102 }
103
104 /// Constructs the `TensorOptions` from a type and a Device. Asserts that
105 /// the device type matches the device type of the type.
106 TensorOptions options(c10::optional<Device> device_opt) const {
107 if (!device_opt.has_value()) {
108 return options(-1);
109 } else {
110 Device device = device_opt.value();
111 AT_ASSERT(device.type() == device_type());
112 return options(device.index());
113 }
114 }
115
116 operator TensorOptions() const {
117 return options();
118 }
119
120 int64_t id() const {
121 return static_cast<int64_t>(backend()) *
122 static_cast<int64_t>(ScalarType::NumOptions) +
123 static_cast<int64_t>(scalarType());
124 }
125
126 Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
127 Storage unsafeStorageFromTH(void * th_pointer, bool retain) const;
128 Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const;
129
130 private:
131 Backend backend_;
132 ScalarType scalar_type_;
133};
134
135} // namespace at
136