#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>

#include <ATen/Functions.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <memory>
namespace at {
// These functions are defined in the code-generated RegisterDispatchKey.cpp
// file. The TorchScript backend is a special case: registration does not
// happen immediately (at static initialization time), so that an externally
// loaded backend has a chance to register itself first; the TorchScript
// backend only registers itself when explicitly initialized.
extern TORCH_API void RegisterTorchScriptLazyNativeFunctions();
extern TORCH_API void RegisterTorchScriptAutogradLazyNativeFunctions();
} // namespace at

namespace torch {
namespace lazy {

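// BackendDeviceType specialization for the TorchScript backend: wraps a
// c10::DeviceType and restricts it to the two device types this backend
// supports, CPU and CUDA.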
struct TSBackendDeviceType : public BackendDeviceType {
  TSBackendDeviceType() = delete;
  explicit TSBackendDeviceType(c10::DeviceType deviceType)
      : BackendDeviceType(static_cast<int8_t>(deviceType)) {
    TORCH_CHECK(
        deviceType == at::kCPU || deviceType == at::kCUDA,
        "TSBackendDeviceType only supports CPU and CUDA");
  }

  std::string toString() const override {
    return c10::DeviceTypeName(static_cast<c10::DeviceType>(type));
  }

  c10::DeviceType c10Type() const {
    return static_cast<c10::DeviceType>(type);
  }
};

class TSBackendImpl : public torch::lazy::BackendImplInterface {
 public:
  TSBackendImpl() {
    // TODO(whc) unify how all our flags are set and parsed as envs
    static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr;
    auto type =
        (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU;
    default_device_type_ = std::make_shared<TSBackendDeviceType>(type);
  }

  const IrBuilder* GetIrBuilder() const override {
    static const IrBuilder* builder = new TorchScriptIrBuilder();
    return builder;
  }

  std::string CreateMetricReport() const override {
    return "TSBackendImpl: N/A";
  }

  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device,
      c10::ArrayRef<const torch::lazy::Node*> post_order,
      torch::lazy::Util::EmissionMap emit_status) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(
        name, device, post_order, emit_status);
  }

  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(name, device);
  }

  std::vector<std::string> GetCompilationDevices(
      const std::string& device,
      c10::ArrayRef<std::string> devices) const override {
    return std::vector<std::string>(devices.begin(), devices.end());
  }

  at::Tensor MakeTensorFromComputationData(
      const torch::lazy::BackendDataPtr data,
      c10::optional<at::ScalarType> logical_scalar_type) const override {
    const auto ts_data = std::static_pointer_cast<TSData>(data);
    return ts_data->data();
  }

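  // Uploads `tensor` to `device`, picking one of three strategies: an async
  // copy when the source already lives on the default (CUDA) device type, a
  // scalar fill for single-element CPU tensors, or a blocking copy otherwise.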
  torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
      const at::Tensor& tensor,
      const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device) const override {
    at::TensorOptions options = tensor.options().device(
        default_device_type_->c10Type(), device.ordinal());
    if (tensor.device().type() == default_device_type_->c10Type() &&
        default_device_type_->c10Type() == at::kCUDA) {
      return std::make_shared<TSData>(
          tensor.to(options, /*non_blocking=*/true), shape, device);
    } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) {
      // Calling .item() on a single-element CPU tensor is fast, and using
      // fill is a safe, async way to copy a single value from CPU to CUDA.
      auto device_tensor = at::full(tensor.sizes(), tensor.item(), options);
      return std::make_shared<TSData>(device_tensor, shape, device);
    } else {
      return std::make_shared<TSData>(
          tensor.to(options, /*non_blocking=*/false), shape, device);
    }
  }

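  // Scalars are not uploaded to the device at all; they are stored host-side
  // and pushed onto the interpreter stack as IValues at execution time (see
  // ExecuteComputation below).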
  torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
      const at::Scalar& scalar,
      const torch::lazy::BackendDevice& device) const override {
    return std::make_shared<TSData>(scalar, device);
  }

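  // For DeviceData nodes, returns the BackendData they wrap; for any other
  // node kind, returns nullptr.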
  torch::lazy::BackendDataPtr GetComputationDataFromNode(
      const Node* node) const override {
    auto* device_data_node = DeviceData::Cast(node);
    if (!device_data_node) {
      return nullptr;
    }
    return device_data_node->data();
  }

  std::string GetComputationBackendText(
      const torch::lazy::ComputationPtr computation) const override {
    auto ts_computation =
        static_cast<torch::lazy::TSComputation*>(computation.get());
    return ts_computation->graph()->toString();
  }

  /////////////// computation client interfaces ///////////////

 public:
  torch::lazy::BackendDataPtr CreateDataPlaceholder(
      const torch::lazy::BackendDevice& device,
      const torch::lazy::Shape& shape) const override;

  std::vector<torch::lazy::ComputationPtr> Compile(
      std::vector<torch::lazy::ComputationPtr> instances) const override;

  std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
      torch::lazy::ComputationPtr computation,
      c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
      const torch::lazy::BackendDevice& device) const override;

  std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType()
      const override {
    return default_device_type_;
  }

  at::DeviceType EagerFallbackDeviceType() const override;

  void SetDefaultDeviceType(int8_t type) override {
    default_device_type_ = std::make_shared<TSBackendDeviceType>(
        static_cast<c10::DeviceType>(type));
  }

  int64_t GetDefaultDeviceOrdinal() const override {
    return default_device_ordinal_;
  }

  void SetDefaultDeviceOrdinal(int64_t ordinal) override {
    default_device_ordinal_ = ordinal;
  }

  std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;

  torch::lazy::BackendDevice GetBackendDevice(
      c10::Device device) const override;

  void SetRngSeed(size_t seed) const override {
    LOG(FATAL) << "Not implemented yet.";
  }

  // std::map<std::string, Metric> GetMetrics() const override { return {}; }

  // MemoryInfo GetMemoryInfo(const std::string& device) override {
  //   LOG(FATAL) << "Not implemented yet.";
  // }

  void PrepareToExit() const override;

 private:
  std::shared_ptr<TSBackendDeviceType> default_device_type_;
  int64_t default_device_ordinal_{0};
};

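// A placeholder TSData carries only a shape and a device; the actual tensor
// is filled in later, once the computation that produces it has run.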
torch::lazy::BackendDataPtr TSBackendImpl::CreateDataPlaceholder(
    const torch::lazy::BackendDevice& device,
    const torch::lazy::Shape& shape) const {
  return std::make_shared<TSData>(shape, device);
}

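// For the TS backend, Compile is effectively a no-op: actual compilation is
// deferred to torch::jit::GraphExecutor when the computation is executed.
// Here we only sanity-check that compilation was triggered from a mark step.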
std::vector<torch::lazy::ComputationPtr> TSBackendImpl::Compile(
    std::vector<torch::lazy::ComputationPtr> instances) const {
  for (const auto& instance : instances) {
    auto ts_computation =
        static_cast<torch::lazy::TSComputation*>(instance.get());
    if (!ts_computation->in_mark_step) {
      LOG(WARNING) << "Compile outside of mark step";
    }
  }
  return instances;
}

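// Marshals the backend data into a torch::jit interpreter stack (scalars as
// IValues, tensors directly), runs the graph executor, and wraps each output
// back into a TSData. Note that GraphExecutor::run reuses `stack` for its
// outputs, and every output is expected to be a tensor.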
std::vector<torch::lazy::BackendDataPtr> TSBackendImpl::ExecuteComputation(
    torch::lazy::ComputationPtr computation,
    c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
    const torch::lazy::BackendDevice& device) const {
  auto ts_computation =
      std::dynamic_pointer_cast<torch::lazy::TSComputation>(computation);
  TORCH_CHECK(ts_computation, "Computation isn't TSComputation");
  torch::jit::GraphExecutor& graph_executor = ts_computation->graph_executor();
  std::vector<torch::jit::IValue> stack;
  for (const auto& argument : arguments) {
    const auto ts_data = std::static_pointer_cast<TSData>(argument);
    if (ts_data->scalar.has_value()) {
      stack.emplace_back(ts_data->scalar.value());
    } else {
      // TODO(whc) should this check be made more general? It's written
      // somewhat oddly.
      TORCH_CHECK(
          static_cast<c10::DeviceType>(default_device_type_->type) !=
              at::kCUDA ||
          ts_data->data().device().type() == at::kCUDA);
      stack.emplace_back(ts_data->data());
    }
  }
  graph_executor.run(stack);
  std::vector<torch::lazy::BackendDataPtr> results;
  for (const torch::jit::IValue& component : stack) {
    at::Tensor result = component.toTensor();
    at::IntArrayRef result_sizes = result.sizes();
    torch::lazy::Shape shape(
        result.scalar_type(),
        std::vector<int64_t>(result_sizes.begin(), result_sizes.end()));
    results.push_back(std::make_shared<TSData>(result, shape, device));
  }
  return results;
}

std::vector<torch::lazy::BackendDevice> TSBackendImpl::GetBackendDevices()
    const {
  std::vector<torch::lazy::BackendDevice> devices;
  // TODO(whc) figure out how to query available devices from pytorch
  devices.emplace_back(GetBackendDevice(c10::Device(c10::kCPU, 0)));
  devices.emplace_back(GetBackendDevice(c10::Device(c10::kCUDA, 0)));
  return devices;
}

torch::lazy::BackendDevice TSBackendImpl::GetBackendDevice(
    c10::Device device) const {
  // Note: we ignore the device type specified by the c10::Device, since it is
  // expected to be a virtual device (lazy::); this will need to change once
  // lazy is supported as a mode.
  return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index());
}

void TSBackendImpl::PrepareToExit() const {}

c10::DeviceType TSBackendImpl::EagerFallbackDeviceType() const {
  // For the TS backend, the hardware device _is_ the eager device.
  return static_cast<c10::DeviceType>(GetDefaultDeviceType()->type);
}

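// Returns the process-wide TSBackendImpl singleton. The instance is
// heap-allocated and never deleted, so it remains valid during static
// destruction at shutdown.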
torch::lazy::BackendImplInterface* GetTSBackendImpl() {
  static TSBackendImpl* ts_backend_impl = new TSBackendImpl();
  return ts_backend_impl;
}

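// Registers the TorchScript lazy backend. This is called explicitly rather
// than at static-initialization time so that an externally loaded backend can
// register itself first (see the note at the top of this file). A minimal
// usage sketch:
//
//   torch::lazy::InitTorchScriptBackend();
//   // Lazy tensors now lower and execute through the TS backend.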
void InitTorchScriptBackend() {
  at::RegisterTorchScriptLazyNativeFunctions();
  at::RegisterTorchScriptAutogradLazyNativeFunctions();
  register_ts_ltc_eager_fallback();
  static std::unique_ptr<BackendRegistrar> s_registrar;
  s_registrar = std::make_unique<BackendRegistrar>(GetTSBackendImpl());

  static LazyGraphExecutor* executor = new LazyGraphExecutor();
  LazyGraphExecutor::Register(executor);
}

} // namespace lazy
} // namespace torch