1#pragma once
2
3#include <memory>
4#include <ostream>
5#include <string>
6
7#include <ATen/Tensor.h>
8#include <c10/macros/Export.h>
9#include <c10/util/Deprecated.h>
10#include <c10/util/Optional.h>
11
12namespace c10 {
13struct Device;
14}
15
16namespace torch {
17namespace lazy {
18
19// Backend should extend it and define their own supported hardware types.
20struct TORCH_API BackendDeviceType {
21 int8_t type{(int8_t)at::kCPU};
22 // Note: previous default value was '0', which actually maps to at::kCPU, at
23 // least now it is explicit, we may want to make default/undefined semantics
24 // more clear though
25 BackendDeviceType() : type((int8_t)at::kCPU) {}
26 BackendDeviceType(int8_t type) : type(type) {}
27
28 virtual ~BackendDeviceType() = default;
29 virtual std::string toString() const {
30 return "Unknown";
31 }
32};
33
34class TORCH_API BackendDevice {
35 public:
36 // The default constructor will set both the device type and ordinal
37 // to backend specific defaults.
38 BackendDevice();
39 BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);
40
41 int8_t type() const;
42 int64_t ordinal() const {
43 return ordinal_;
44 }
45
46 bool operator==(const BackendDevice& other) const {
47 return compare(other) == 0;
48 }
49 bool operator!=(const BackendDevice& other) const {
50 return compare(other) != 0;
51 }
52 bool operator<(const BackendDevice& rhs) const {
53 return compare(rhs) < 0;
54 }
55
56 std::string toString() const;
57
58 private:
59 int compare(const BackendDevice& rhs) const;
60
61 // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied.
62 std::shared_ptr<BackendDeviceType> type_;
63 int64_t ordinal_;
64};
65
66TORCH_API std::ostream& operator<<(
67 std::ostream& os,
68 const BackendDevice& device);
69
70// Helpers for converting a c10::Device to BackendDevice and vice versa.
71TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device);
72TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);
73
74// Tries to extract the backend device out of the lazy tensor. Returns nullopt
75// if the input is not a lazy tensor.
76TORCH_API c10::optional<BackendDevice> GetBackendDevice(
77 const at::ITensorListRef tensors);
78TORCH_API c10::optional<BackendDevice> GetBackendDevice(
79 const at::TensorList tensors);
80TORCH_API c10::optional<BackendDevice> GetBackendDevice(
81 const at::Tensor& tensor);
82TORCH_API c10::optional<BackendDevice> GetBackendDevice(
83 const c10::optional<c10::Device> device);
84
85// For variadic template.
86TORCH_API c10::optional<BackendDevice> GetBackendDevice();
87
88template <typename T, typename... Args>
89c10::optional<BackendDevice> GetBackendDevice(
90 const T& tensor,
91 const Args&... forward_tensors) {
92 auto optional_device = GetBackendDevice(tensor);
93 if (optional_device) {
94 return optional_device;
95 }
96 return GetBackendDevice(forward_tensors...);
97}
98
99} // namespace lazy
100} // namespace torch
101