1 | #pragma once |
2 | |
3 | #include <c10/core/Backend.h> |
4 | #include <c10/core/ScalarType.h> |
5 | #include <c10/core/Layout.h> |
6 | #include <c10/core/TensorOptions.h> |
7 | #include <c10/core/Storage.h> |
8 | #include <ATen/core/DeprecatedTypePropertiesRegistry.h> |
9 | #include <ATen/core/Generator.h> |
10 | |
11 | |
12 | namespace at { |
13 | |
14 | class Tensor; |
15 | |
16 | // This class specifies a Backend and a ScalarType. Currently, it primarily |
17 | // serves as a replacement return value for Tensor::type(). Previously, |
18 | // Tensor::type() returned Type&, but we are changing Type to not be |
19 | // dtype-specific. |
20 | class TORCH_API DeprecatedTypeProperties { |
21 | public: |
22 | DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) |
23 | : backend_(backend), scalar_type_(scalar_type) {} |
24 | |
25 | Backend backend() const { |
26 | return backend_; |
27 | } |
28 | |
29 | Layout layout() const { |
30 | return layout_from_backend(backend_); |
31 | } |
32 | |
33 | bool is_sparse() const { |
34 | return layout_from_backend(backend()) == kSparse; |
35 | } |
36 | |
37 | bool is_sparse_csr() const { |
38 | return layout_from_backend(backend()) == kSparseCsr; |
39 | } |
40 | |
41 | DeviceType device_type() const { |
42 | return backendToDeviceType(backend_); |
43 | } |
44 | |
45 | bool is_cuda() const { |
46 | return backendToDeviceType(backend_) == kCUDA; |
47 | } |
48 | |
49 | ScalarType scalarType() const { |
50 | return scalar_type_; |
51 | } |
52 | |
53 | caffe2::TypeMeta typeMeta() const { |
54 | return scalarTypeToTypeMeta(scalar_type_); |
55 | } |
56 | |
57 | bool operator==(const DeprecatedTypeProperties& other) const { |
58 | return backend_ == other.backend() && scalar_type_ == other.scalarType(); |
59 | } |
60 | |
61 | bool operator!=(const DeprecatedTypeProperties& other) const { |
62 | return !(*this == other); |
63 | } |
64 | |
65 | std::string toString() const { |
66 | std::string base_str; |
67 | if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) { |
68 | base_str = "UndefinedType" ; |
69 | } else { |
70 | base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type" ; |
71 | } |
72 | return base_str; |
73 | } |
74 | |
75 | DeprecatedTypeProperties & toBackend(Backend b) const { |
76 | return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
77 | b, scalar_type_); |
78 | } |
79 | |
80 | DeprecatedTypeProperties & toScalarType(ScalarType s) const { |
81 | return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( |
82 | backend_, s); |
83 | } |
84 | |
85 | DeprecatedTypeProperties & cpu() const { |
86 | return toBackend(Backend::CPU); |
87 | } |
88 | |
89 | DeprecatedTypeProperties & cuda() const { |
90 | return toBackend(Backend::CUDA); |
91 | } |
92 | |
93 | DeprecatedTypeProperties & hip() const { |
94 | return toBackend(Backend::HIP); |
95 | } |
96 | |
97 | /// Constructs the `TensorOptions` from a type and a `device_index`. |
98 | TensorOptions options(int16_t device_index = -1) const { |
99 | return TensorOptions().dtype(typeMeta()) |
100 | .device(device_type(), static_cast<c10::DeviceIndex>(device_index)) |
101 | .layout(layout()); |
102 | } |
103 | |
104 | /// Constructs the `TensorOptions` from a type and a Device. Asserts that |
105 | /// the device type matches the device type of the type. |
106 | TensorOptions options(c10::optional<Device> device_opt) const { |
107 | if (!device_opt.has_value()) { |
108 | return options(-1); |
109 | } else { |
110 | Device device = device_opt.value(); |
111 | AT_ASSERT(device.type() == device_type()); |
112 | return options(device.index()); |
113 | } |
114 | } |
115 | |
116 | operator TensorOptions() const { |
117 | return options(); |
118 | } |
119 | |
120 | int64_t id() const { |
121 | return static_cast<int64_t>(backend()) * |
122 | static_cast<int64_t>(ScalarType::NumOptions) + |
123 | static_cast<int64_t>(scalarType()); |
124 | } |
125 | |
126 | Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const; |
127 | Storage unsafeStorageFromTH(void * th_pointer, bool retain) const; |
128 | Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const; |
129 | |
130 | private: |
131 | Backend backend_; |
132 | ScalarType scalar_type_; |
133 | }; |
134 | |
135 | } // namespace at |
136 | |