1 | #pragma once |
---|---|
2 | |
3 | #include <memory> |
4 | #include <ostream> |
5 | #include <string> |
6 | |
7 | #include <ATen/Tensor.h> |
8 | #include <c10/macros/Export.h> |
9 | #include <c10/util/Deprecated.h> |
10 | #include <c10/util/Optional.h> |
11 | |
// Forward-declare c10::Device: this header only mentions it by const
// reference and inside c10::optional parameters, so the full definition
// is not needed here.
namespace c10 {
struct Device;
}
15 | |
16 | namespace torch { |
17 | namespace lazy { |
18 | |
19 | // Backend should extend it and define their own supported hardware types. |
20 | struct TORCH_API BackendDeviceType { |
21 | int8_t type{(int8_t)at::kCPU}; |
22 | // Note: previous default value was '0', which actually maps to at::kCPU, at |
23 | // least now it is explicit, we may want to make default/undefined semantics |
24 | // more clear though |
25 | BackendDeviceType() : type((int8_t)at::kCPU) {} |
26 | BackendDeviceType(int8_t type) : type(type) {} |
27 | |
28 | virtual ~BackendDeviceType() = default; |
29 | virtual std::string toString() const { |
30 | return "Unknown"; |
31 | } |
32 | }; |
33 | |
34 | class TORCH_API BackendDevice { |
35 | public: |
36 | // The default constructor will set both the device type and ordinal |
37 | // to backend specific defaults. |
38 | BackendDevice(); |
39 | BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal); |
40 | |
41 | int8_t type() const; |
42 | int64_t ordinal() const { |
43 | return ordinal_; |
44 | } |
45 | |
46 | bool operator==(const BackendDevice& other) const { |
47 | return compare(other) == 0; |
48 | } |
49 | bool operator!=(const BackendDevice& other) const { |
50 | return compare(other) != 0; |
51 | } |
52 | bool operator<(const BackendDevice& rhs) const { |
53 | return compare(rhs) < 0; |
54 | } |
55 | |
56 | std::string toString() const; |
57 | |
58 | private: |
59 | int compare(const BackendDevice& rhs) const; |
60 | |
61 | // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied. |
62 | std::shared_ptr<BackendDeviceType> type_; |
63 | int64_t ordinal_; |
64 | }; |
65 | |
// Stream-insertion operator for BackendDevice (defined out of line).
TORCH_API std::ostream& operator<<(
    std::ostream& os,
    const BackendDevice& device);

// Helpers for converting a c10::Device to BackendDevice and vice versa.
TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device);
TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);

// Tries to extract the backend device out of the lazy tensor. Returns nullopt
// if the input is not a lazy tensor.
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const at::ITensorListRef tensors);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const at::TensorList tensors);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const at::Tensor& tensor);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const c10::optional<c10::Device> device);

// For variadic template: nullary base case terminating the recursion of the
// template overload below.
TORCH_API c10::optional<BackendDevice> GetBackendDevice();
87 | |
88 | template <typename T, typename... Args> |
89 | c10::optional<BackendDevice> GetBackendDevice( |
90 | const T& tensor, |
91 | const Args&... forward_tensors) { |
92 | auto optional_device = GetBackendDevice(tensor); |
93 | if (optional_device) { |
94 | return optional_device; |
95 | } |
96 | return GetBackendDevice(forward_tensors...); |
97 | } |
98 | |
99 | } // namespace lazy |
100 | } // namespace torch |
101 |