1#include <torch/csrc/lazy/backend/backend_device.h>
2
3#include <c10/core/Device.h>
4#include <c10/util/Exception.h>
5#include <c10/util/Optional.h>
6#include <c10/util/StringUtil.h>
7#include <torch/csrc/lazy/backend/backend_interface.h>
8#include <torch/csrc/lazy/core/tensor.h>
9
10namespace torch {
11namespace lazy {
12
13BackendDevice::BackendDevice()
14 : type_(getBackend()->GetDefaultDeviceType()),
15 ordinal_(getBackend()->GetDefaultDeviceOrdinal()) {}
16
17BackendDevice::BackendDevice(
18 std::shared_ptr<BackendDeviceType>&& type,
19 int64_t ordinal)
20 : type_(std::move(type)), ordinal_(ordinal) {}
21
22int8_t BackendDevice::type() const {
23 TORCH_INTERNAL_ASSERT(type_);
24 return type_->type;
25}
26
27std::string BackendDevice::toString() const {
28 TORCH_INTERNAL_ASSERT(type_);
29 return c10::str(type_->toString(), ordinal_);
30}
31
32int BackendDevice::compare(const BackendDevice& rhs) const {
33 if (type() != rhs.type()) {
34 return type() < rhs.type() ? -1 : +1;
35 }
36 return ordinal_ < rhs.ordinal_ ? -1 : (ordinal_ > rhs.ordinal_ ? +1 : 0);
37}
38
39std::ostream& operator<<(std::ostream& os, const BackendDevice& device) {
40 os << device.toString();
41 return os;
42}
43
44BackendDevice atenDeviceToBackendDevice(const c10::Device& device) {
45 TORCH_CHECK(device.type() == at::kLazy, device);
46 int64_t ordinal = device.has_index()
47 ? device.index()
48 : getBackend()->GetDefaultDeviceOrdinal();
49 return BackendDevice(getBackend()->GetDefaultDeviceType(), ordinal);
50}
51
52// TODO(whc) refactor this: we need to support non 1 on 1 mapping for torch/XLA.
53c10::Device backendDeviceToAtenDevice(const BackendDevice& device) {
54 return c10::Device(at::kLazy, device.ordinal());
55}
56
57c10::optional<BackendDevice> GetBackendDevice(at::ITensorListRef tensors) {
58 for (auto& tensor : tensors) {
59 if (auto lt = TryGetLtcTensor(tensor)) {
60 return lt->GetDevice();
61 }
62 }
63 return c10::nullopt;
64}
65
66c10::optional<BackendDevice> GetBackendDevice(at::TensorList tensors) {
67 return GetBackendDevice(at::ITensorListRef(tensors));
68}
69
70c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) {
71 if (auto lt = TryGetLtcTensor(tensor)) {
72 return lt->GetDevice();
73 }
74 return c10::nullopt;
75}
76
77c10::optional<BackendDevice> GetBackendDevice(
78 const c10::optional<c10::Device> device) {
79 if (device) {
80 return c10::make_optional(atenDeviceToBackendDevice(*device));
81 }
82 return c10::nullopt;
83}
84
85c10::optional<BackendDevice> GetBackendDevice() {
86 return c10::nullopt;
87}
88
89} // namespace lazy
90} // namespace torch
91