1 | #include <torch/csrc/lazy/backend/backend_device.h> |
---|---|
2 | |
3 | #include <c10/core/Device.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <c10/util/StringUtil.h> |
7 | #include <torch/csrc/lazy/backend/backend_interface.h> |
8 | #include <torch/csrc/lazy/core/tensor.h> |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | |
13 | BackendDevice::BackendDevice() |
14 | : type_(getBackend()->GetDefaultDeviceType()), |
15 | ordinal_(getBackend()->GetDefaultDeviceOrdinal()) {} |
16 | |
17 | BackendDevice::BackendDevice( |
18 | std::shared_ptr<BackendDeviceType>&& type, |
19 | int64_t ordinal) |
20 | : type_(std::move(type)), ordinal_(ordinal) {} |
21 | |
22 | int8_t BackendDevice::type() const { |
23 | TORCH_INTERNAL_ASSERT(type_); |
24 | return type_->type; |
25 | } |
26 | |
27 | std::string BackendDevice::toString() const { |
28 | TORCH_INTERNAL_ASSERT(type_); |
29 | return c10::str(type_->toString(), ordinal_); |
30 | } |
31 | |
32 | int BackendDevice::compare(const BackendDevice& rhs) const { |
33 | if (type() != rhs.type()) { |
34 | return type() < rhs.type() ? -1 : +1; |
35 | } |
36 | return ordinal_ < rhs.ordinal_ ? -1 : (ordinal_ > rhs.ordinal_ ? +1 : 0); |
37 | } |
38 | |
39 | std::ostream& operator<<(std::ostream& os, const BackendDevice& device) { |
40 | os << device.toString(); |
41 | return os; |
42 | } |
43 | |
44 | BackendDevice atenDeviceToBackendDevice(const c10::Device& device) { |
45 | TORCH_CHECK(device.type() == at::kLazy, device); |
46 | int64_t ordinal = device.has_index() |
47 | ? device.index() |
48 | : getBackend()->GetDefaultDeviceOrdinal(); |
49 | return BackendDevice(getBackend()->GetDefaultDeviceType(), ordinal); |
50 | } |
51 | |
52 | // TODO(whc) refactor this: we need to support non 1 on 1 mapping for torch/XLA. |
53 | c10::Device backendDeviceToAtenDevice(const BackendDevice& device) { |
54 | return c10::Device(at::kLazy, device.ordinal()); |
55 | } |
56 | |
57 | c10::optional<BackendDevice> GetBackendDevice(at::ITensorListRef tensors) { |
58 | for (auto& tensor : tensors) { |
59 | if (auto lt = TryGetLtcTensor(tensor)) { |
60 | return lt->GetDevice(); |
61 | } |
62 | } |
63 | return c10::nullopt; |
64 | } |
65 | |
66 | c10::optional<BackendDevice> GetBackendDevice(at::TensorList tensors) { |
67 | return GetBackendDevice(at::ITensorListRef(tensors)); |
68 | } |
69 | |
70 | c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) { |
71 | if (auto lt = TryGetLtcTensor(tensor)) { |
72 | return lt->GetDevice(); |
73 | } |
74 | return c10::nullopt; |
75 | } |
76 | |
77 | c10::optional<BackendDevice> GetBackendDevice( |
78 | const c10::optional<c10::Device> device) { |
79 | if (device) { |
80 | return c10::make_optional(atenDeviceToBackendDevice(*device)); |
81 | } |
82 | return c10::nullopt; |
83 | } |
84 | |
85 | c10::optional<BackendDevice> GetBackendDevice() { |
86 | return c10::nullopt; |
87 | } |
88 | |
89 | } // namespace lazy |
90 | } // namespace torch |
91 |