1#pragma once
2
3#include <ATen/Tensor.h>
4#include <torch/csrc/lazy/backend/backend_data.h>
5#include <torch/csrc/lazy/backend/backend_device.h>
6#include <torch/csrc/lazy/backend/lowering_context.h>
7#include <torch/csrc/lazy/core/lazy_graph_executor.h>
8#include <torch/csrc/lazy/core/shape.h>
9#include <torch/csrc/lazy/core/tensor.h>
10#include <atomic>
11
12namespace torch {
13namespace lazy {
14
15struct IrBuilder;
16
/**
 * Interface implemented by a lazy-tensor backend. Its methods are grouped as
 * Initialization/Teardown, Configuration, IR Tracing, Data Transfer,
 * Lowering/Compilation/Execution, Device Configuration and Debug/Metrics.
 *
 * Work in progress - don't treat this as a stable interface yet!
 *
 * NOTE(review): every pure-virtual signature below is overridden by external
 * backends; changing any parameter type (even `const T` -> `const T&`) breaks
 * those overrides, so edit signatures only together with all implementations.
 */
class TORCH_API BackendImplInterface {
 public:
  virtual ~BackendImplInterface() = default;

  /**
   * Initialization/Teardown
   * */
  // No-op by default. Allows custom functionality to be exposed through
  // extension bindings.
  virtual void InitializeAtenBindings() const {}

  // Hook invoked before the process exits; what it must release or flush is
  // backend-defined (presumably device/runtime resources - confirm with the
  // caller).
  virtual void PrepareToExit() const = 0;

  /**
   * Configuration
   * */

  // Sets the random-number seed used by the backend; seeding semantics are
  // backend-defined.
  virtual void SetRngSeed(size_t seed) const = 0;

  /**
   * IR Tracing
   * */

  // Returns the IrBuilder this backend uses while tracing. The raw pointer is
  // non-owning; presumably the backend keeps it alive for its own lifetime.
  virtual const IrBuilder* GetIrBuilder() const = 0;

  /**
   * Data Transfer
   * */

  // Creates backend data for `tensor`, described by `shape`, on `device`.
  // Whether the tensor contents are copied is backend-defined.
  virtual BackendDataPtr MakeComputationDataFromTensor(
      const at::Tensor& tensor,
      const Shape& shape,
      const BackendDevice& device) const = 0;
  // Creates backend data holding the single value `scalar` on `device`.
  virtual BackendDataPtr MakeComputationDataFromScalar(
      const at::Scalar& scalar,
      const torch::lazy::BackendDevice& device) const = 0;
  // Creates backend data of the given `shape` on `device` without supplying
  // contents (presumably filled in later, e.g. as a computation output -
  // confirm).
  virtual BackendDataPtr CreateDataPlaceholder(
      const BackendDevice& device,
      const Shape& shape) const = 0;

  // Gets backend data if the node is a device data node. Otherwise returns
  // nullptr.
  virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0;

  // Converts backend data back into an at::Tensor. When
  // `logical_scalar_type` is set it presumably selects the dtype of the
  // returned tensor - confirm in implementations.
  // NOTE(review): `data` is taken by const value, so each call copies the
  // pointer object (an atomic refcount update if BackendDataPtr is a
  // shared_ptr); kept as-is because overrides must match this signature.
  virtual at::Tensor MakeTensorFromComputationData(
      const BackendDataPtr data,
      c10::optional<at::ScalarType> logical_scalar_type) const = 0;

  /**
   * Lowering, Compilation, Execution
   * */

  // Creates a lowering context named `name` for `device`, seeded with the
  // nodes of `post_order` and the emission state `emit_status`. The caller
  // owns the returned context.
  virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const torch::lazy::Node*> post_order,
      Util::EmissionMap emit_status) const = 0;

  // Creates an empty lowering context named `name` for `device`. The caller
  // owns the returned context.
  virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
      const std::string& name,
      BackendDevice device) const = 0;

  // Returns the device names a computation for `device` should be compiled
  // for, given the candidate `devices`.
  // TODO(whc) need to keep this?
  virtual std::vector<std::string> GetCompilationDevices(
      const std::string& device,
      c10::ArrayRef<std::string> devices) const = 0;

  // Compiles the given computations and returns the compiled computations
  // (presumably one per input, in the same order - confirm).
  virtual std::vector<ComputationPtr> Compile(
      std::vector<ComputationPtr> instances) const = 0;

  // Runs `computation` on `device` with `arguments` as inputs and returns
  // the resulting backend data.
  virtual std::vector<BackendDataPtr> ExecuteComputation(
      torch::lazy::ComputationPtr computation,
      c10::ArrayRef<BackendDataPtr> arguments,
      const BackendDevice& device) const = 0;

  /**
   * Device Configuration
   * */

  // Set or get the default device type.
  // For backends used with virtual c10::Devices, this configures what real
  // device type the backend should use, and matters if the backend supports
  // more than one type of real device.
  // NOTE(review): the setters in this section are non-const, while
  // getBackend() only hands out a `const BackendImplInterface*`, so they
  // cannot be called through it without a const_cast.
  virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const = 0;
  virtual void SetDefaultDeviceType(int8_t type) = 0;

  // Set or get the default device ordinal.
  // For backends that support multiple devices, this configures which
  // device the backend uses by default.
  virtual int64_t GetDefaultDeviceOrdinal() const = 0;
  virtual void SetDefaultDeviceOrdinal(int64_t) = 0;

  // Specifies which ATen device type should be used for eager fallback.
  // This may change depending on the current default device type.
  virtual at::DeviceType EagerFallbackDeviceType() const = 0;

  // Query all available backend devices
  virtual std::vector<BackendDevice> GetBackendDevices() const = 0;

  // Returns a backend-specific metrics report. The default implementation
  // returns an empty string.
  virtual std::string CreateMetricReport() const {
    return "";
  }

  // Map a particular c10:: device to a concrete backend device.
  // Note: c10:: devices may be virtual or concrete. xla:: and lazy:: are
  // virtual devices, meaning they may map to a gpu, tpu, etc. behind the
  // scenes. In the future, non-virtual c10:: devices may also use lazy tensors
  // through a mode, in which case these APIs should still work, but should be
  // identity mappings.
  virtual BackendDevice GetBackendDevice(c10::Device device) const = 0;

  // TODO(whc)
  // Additional APIs expected for supporting distributed training, to be
  // designed

  /**
   * Debug/Metrics
   * */

  // virtual std::map<std::string, Metric> GetMetrics() const = 0;

  // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;

  // Returns a textual rendering of `computation` in the backend's own
  // representation (presumably for debugging/dumps - format is
  // backend-defined).
  virtual std::string GetComputationBackendText(
      const ComputationPtr computation) const = 0;
};
146
147class TORCH_API BackendRegistrar {
148 public:
149 BackendRegistrar(const BackendImplInterface* backend_impl_interface);
150};
151
// Presumably reports whether a backend has been registered (e.g. through
// BackendRegistrar) - confirm against the out-of-line definition.
TORCH_API bool hasBackend();
// Returns the registered backend implementation as a non-owning const pointer.
// NOTE(review): behavior when no backend is registered is defined out of line;
// check hasBackend() first unless that definition guarantees an error.
TORCH_API const BackendImplInterface* getBackend();

// Returns the IrBuilder of the registered backend (presumably forwarding to
// getBackend()->GetIrBuilder() - confirm in the definition).
TORCH_API const IrBuilder* getIrBuilder();
156
157} // namespace lazy
158} // namespace torch
159