1 | #pragma once |
2 | |
3 | #include <ATen/Tensor.h> |
4 | #include <torch/csrc/lazy/backend/backend_data.h> |
5 | #include <torch/csrc/lazy/backend/backend_device.h> |
6 | #include <torch/csrc/lazy/backend/lowering_context.h> |
7 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
8 | #include <torch/csrc/lazy/core/shape.h> |
9 | #include <torch/csrc/lazy/core/tensor.h> |
10 | #include <atomic> |
11 | |
12 | namespace torch { |
13 | namespace lazy { |
14 | |
15 | struct IrBuilder; |
16 | |
17 | /** |
18 | * Work in progress- don't treat this as a stable interface yet! |
19 | */ |
20 | class TORCH_API BackendImplInterface { |
21 | public: |
22 | virtual ~BackendImplInterface() = default; |
23 | |
24 | /** |
25 | * Initialization/Teardown |
26 | * */ |
27 | // No-op by default. Allows custom functionality to be exposed through |
28 | // extension bindings. |
29 | virtual void InitializeAtenBindings() const {} |
30 | |
31 | virtual void PrepareToExit() const = 0; |
32 | |
33 | /** |
34 | * Configuration |
35 | * */ |
36 | |
37 | virtual void SetRngSeed(size_t seed) const = 0; |
38 | |
39 | /** |
40 | * IR Tracing |
41 | * */ |
42 | |
43 | virtual const IrBuilder* GetIrBuilder() const = 0; |
44 | |
45 | /** |
46 | * Data Transfer |
47 | * */ |
48 | |
49 | virtual BackendDataPtr MakeComputationDataFromTensor( |
50 | const at::Tensor& tensor, |
51 | const Shape& shape, |
52 | const BackendDevice& device) const = 0; |
53 | virtual BackendDataPtr MakeComputationDataFromScalar( |
54 | const at::Scalar& scalar, |
55 | const torch::lazy::BackendDevice& device) const = 0; |
56 | virtual BackendDataPtr CreateDataPlaceholder( |
57 | const BackendDevice& device, |
58 | const Shape& shape) const = 0; |
59 | |
60 | // Gets backend data if the node is a device data node. Otherwise returns |
61 | // nullptr |
62 | virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0; |
63 | |
64 | virtual at::Tensor MakeTensorFromComputationData( |
65 | const BackendDataPtr data, |
66 | c10::optional<at::ScalarType> logical_scalar_type) const = 0; |
67 | |
68 | /** |
69 | * Lowering, Compilation, Execution |
70 | * */ |
71 | |
72 | virtual std::unique_ptr<LoweringContext> CreateLoweringContext( |
73 | const std::string& name, |
74 | BackendDevice device, |
75 | c10::ArrayRef<const torch::lazy::Node*> post_order, |
76 | Util::EmissionMap emit_status) const = 0; |
77 | |
78 | virtual std::unique_ptr<LoweringContext> CreateLoweringContext( |
79 | const std::string& name, |
80 | BackendDevice device) const = 0; |
81 | |
82 | // TODO(whc) need to keep this? |
83 | virtual std::vector<std::string> GetCompilationDevices( |
84 | const std::string& device, |
85 | c10::ArrayRef<std::string> devices) const = 0; |
86 | |
87 | virtual std::vector<ComputationPtr> Compile( |
88 | std::vector<ComputationPtr> instances) const = 0; |
89 | |
90 | virtual std::vector<BackendDataPtr> ExecuteComputation( |
91 | torch::lazy::ComputationPtr computation, |
92 | c10::ArrayRef<BackendDataPtr> arguments, |
93 | const BackendDevice& device) const = 0; |
94 | |
95 | /** |
96 | * Device Configuration |
97 | * */ |
98 | |
99 | // Set or get the default device type. |
100 | // For backends used with virtual c10::Devices, this configures what real |
101 | // device type the backend should use, and matters if the backend supports |
102 | // more than one type of real device. |
103 | virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const = 0; |
104 | virtual void SetDefaultDeviceType(int8_t type) = 0; |
105 | |
106 | // Set or get the default device ordinal. |
107 | // For backends that supports multi-device, this configures what the |
108 | // default device the backend should use. |
109 | virtual int64_t GetDefaultDeviceOrdinal() const = 0; |
110 | virtual void SetDefaultDeviceOrdinal(int64_t) = 0; |
111 | |
112 | // Specify which aten device should be used for eager fallback |
113 | // may change depending on current 'Default' DeviceType |
114 | virtual at::DeviceType EagerFallbackDeviceType() const = 0; |
115 | |
116 | // Query all available backend devices |
117 | virtual std::vector<BackendDevice> GetBackendDevices() const = 0; |
118 | |
119 | virtual std::string CreateMetricReport() const { |
120 | return "" ; |
121 | } |
122 | |
123 | // Map a particular c10:: device to a concrete backend device |
124 | // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are |
125 | // virtual devices, meaning they may map to a gpu, tpu, etc. behind the |
126 | // scenes. In the future, non-virtual c10:: devices may also use lazy tensors |
127 | // through a mode, in which case these APIs should still work, but should be |
128 | // identity mappings. |
129 | virtual BackendDevice GetBackendDevice(c10::Device device) const = 0; |
130 | |
131 | // TODO(whc) |
132 | // Additional APIs expected for supporting distributed training, to be |
133 | // designed |
134 | |
135 | /** |
136 | * Debug/Metrics |
137 | * */ |
138 | |
139 | // virtual std::map<std::string, Metric> GetMetrics() const = 0; |
140 | |
141 | // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; |
142 | |
143 | virtual std::string GetComputationBackendText( |
144 | const ComputationPtr computation) const = 0; |
145 | }; |
146 | |
// Registers a backend implementation at construction time — presumably
// making it the backend returned by getBackend(); the actual registration
// logic lives in the .cpp (not visible here). Typically instantiated as a
// static object by the backend library so registration happens at load time
// — TODO confirm against the implementation.
// NOTE(review): single-argument constructor is not `explicit`; consider
// marking it so if no caller relies on implicit conversion.
class TORCH_API BackendRegistrar {
 public:
  BackendRegistrar(const BackendImplInterface* backend_impl_interface);
};
151 | |
// Returns true if a backend implementation has been registered
// (see BackendRegistrar).
TORCH_API bool hasBackend();

// Returns the registered backend implementation. Behavior when no backend
// is registered is defined in the implementation — TODO confirm (likely an
// error/check rather than nullptr).
TORCH_API const BackendImplInterface* getBackend();

// Convenience accessor for the registered backend's IrBuilder
// (presumably equivalent to getBackend()->GetIrBuilder() — confirm in .cpp).
TORCH_API const IrBuilder* getIrBuilder();
156 | |
157 | } // namespace lazy |
158 | } // namespace torch |
159 | |