1#pragma once
2
3#include <c10/core/DeviceType.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/Exception.h>
6
7#include <cstddef>
8#include <functional>
9#include <iosfwd>
10#include <string>
11
12namespace c10 {
13
/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated with; try to use Device rather than
/// DeviceIndex directly.
18using DeviceIndex = int8_t;
19
20/// Represents a a compute device on which a tensor is located. A device is
21/// uniquely identified by a type, which specifies the type of machine it is
22/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the
23/// specific compute device when there is more than one of a certain type. The
24/// device index is optional, and in its defaulted state represents (abstractly)
25/// "the current device". Further, there are two constraints on the value of the
26/// device index, if one is explicitly stored:
27/// 1. A negative index represents the current device, a non-negative index
28/// represents a specific, concrete device,
29/// 2. When the device type is CPU, the device index must be zero.
30struct C10_API Device final {
31 using Type = DeviceType;
32
33 /// Constructs a new `Device` from a `DeviceType` and an optional device
34 /// index.
35 /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
36 : type_(type), index_(index) {
37 validate();
38 }
39
40 /// Constructs a `Device` from a string description, for convenience.
41 /// The string supplied must follow the following schema:
42 /// `(cpu|cuda)[:<device-index>]`
43 /// where `cpu` or `cuda` specifies the device type, and
44 /// `:<device-index>` optionally specifies a device index.
45 /* implicit */ Device(const std::string& device_string);
46
47 /// Returns true if the type and index of this `Device` matches that of
48 /// `other`.
49 bool operator==(const Device& other) const noexcept {
50 return this->type_ == other.type_ && this->index_ == other.index_;
51 }
52
53 /// Returns true if the type or index of this `Device` differs from that of
54 /// `other`.
55 bool operator!=(const Device& other) const noexcept {
56 return !(*this == other);
57 }
58
59 /// Sets the device index.
60 void set_index(DeviceIndex index) {
61 index_ = index;
62 }
63
64 /// Returns the type of device this is.
65 DeviceType type() const noexcept {
66 return type_;
67 }
68
69 /// Returns the optional index.
70 DeviceIndex index() const noexcept {
71 return index_;
72 }
73
74 /// Returns true if the device has a non-default index.
75 bool has_index() const noexcept {
76 return index_ != -1;
77 }
78
79 /// Return true if the device is of CUDA type.
80 bool is_cuda() const noexcept {
81 return type_ == DeviceType::CUDA;
82 }
83
84 /// Return true if the device is of MPS type.
85 bool is_mps() const noexcept {
86 return type_ == DeviceType::MPS;
87 }
88
89 /// Return true if the device is of HIP type.
90 bool is_hip() const noexcept {
91 return type_ == DeviceType::HIP;
92 }
93
94 /// Return true if the device is of VE type.
95 bool is_ve() const noexcept {
96 return type_ == DeviceType::VE;
97 }
98
99 /// Return true if the device is of XPU type.
100 bool is_xpu() const noexcept {
101 return type_ == DeviceType::XPU;
102 }
103
104 /// Return true if the device is of IPU type.
105 bool is_ipu() const noexcept {
106 return type_ == DeviceType::IPU;
107 }
108
109 /// Return true if the device is of XLA type.
110 bool is_xla() const noexcept {
111 return type_ == DeviceType::XLA;
112 }
113
114 /// Return true if the device is of HPU type.
115 bool is_hpu() const noexcept {
116 return type_ == DeviceType::HPU;
117 }
118
119 /// Return true if the device is of Lazy type.
120 bool is_lazy() const noexcept {
121 return type_ == DeviceType::Lazy;
122 }
123
124 /// Return true if the device is of Vulkan type.
125 bool is_vulkan() const noexcept {
126 return type_ == DeviceType::Vulkan;
127 }
128
129 /// Return true if the device is of Metal type.
130 bool is_metal() const noexcept {
131 return type_ == DeviceType::Metal;
132 }
133
134 /// Return true if the device is of ORT type.
135 bool is_ort() const noexcept {
136 return type_ == DeviceType::ORT;
137 }
138
139 /// Return true if the device is of META type.
140 bool is_meta() const noexcept {
141 return type_ == DeviceType::Meta;
142 }
143
144 /// Return true if the device is of CPU type.
145 bool is_cpu() const noexcept {
146 return type_ == DeviceType::CPU;
147 }
148
149 /// Return true if the device supports arbirtary strides.
150 bool supports_as_strided() const noexcept {
151 return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
152 type_ != DeviceType::Lazy;
153 }
154
155 /// Same string as returned from operator<<.
156 std::string str() const;
157
158 private:
159 DeviceType type_;
160 DeviceIndex index_ = -1;
161 void validate() {
162 // Removing these checks in release builds noticeably improves
163 // performance in micro-benchmarks.
164 // This is safe to do, because backends that use the DeviceIndex
165 // have a later check when we actually try to switch to that device.
166 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
167 index_ == -1 || index_ >= 0,
168 "Device index must be -1 or non-negative, got ",
169 (int)index_);
170 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
171 !is_cpu() || index_ <= 0,
172 "CPU device index must be -1 or zero, got ",
173 (int)index_);
174 }
175};
176
177C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
178
179} // namespace c10
180
181namespace std {
182template <>
183struct hash<c10::Device> {
184 size_t operator()(c10::Device d) const noexcept {
185 // Are you here because this static assert failed? Make sure you ensure
186 // that the bitmasking code below is updated accordingly!
187 static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
188 static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
189 // Note [Hazard when concatenating signed integers]
190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
191 // We must first convert to a same-sized unsigned type, before promoting to
192 // the result type, to prevent sign extension when any of the values is -1.
193 // If sign extension occurs, you'll clobber all of the values in the MSB
194 // half of the resulting integer.
195 //
196 // Technically, by C/C++ integer promotion rules, we only need one of the
197 // uint32_t casts to the result type, but we put in both for explicitness's
198 // sake.
199 uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
200 << 16 |
201 static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
202 return std::hash<uint32_t>{}(bits);
203 }
204};
205} // namespace std
206