1 | #include <torch/csrc/lazy/ts_backend/ts_backend_impl.h> |
2 | |
3 | #include <ATen/Functions.h> |
4 | #include <torch/csrc/lazy/backend/backend_device.h> |
5 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
6 | #include <torch/csrc/lazy/generated/LazyNativeFunctions.h> |
7 | #include <torch/csrc/lazy/ts_backend/config.h> |
8 | #include <torch/csrc/lazy/ts_backend/ir_builder.h> |
9 | #include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h> |
10 | #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h> |
11 | #include <memory> |
12 | |
13 | namespace at { |
14 | // This function is defined in the codegenerated RegisterDispatchKey.cpp file. |
15 | // For the TorchScript backend, we have a special case where the registration |
16 | // does not happen immediately (at static initialization time), so that if an |
17 | // external backend is loaded, it has a chance to register itself, and |
18 | // TorchScript only registers itself if explicitly initialized |
19 | extern TORCH_API void RegisterTorchScriptLazyNativeFunctions(); |
20 | extern TORCH_API void RegisterTorchScriptAutogradLazyNativeFunctions(); |
21 | } // namespace at |
22 | |
23 | namespace torch { |
24 | namespace lazy { |
25 | |
26 | struct TSBackendDeviceType : public BackendDeviceType { |
27 | TSBackendDeviceType() = delete; |
28 | TSBackendDeviceType(c10::DeviceType deviceType) |
29 | : BackendDeviceType((int8_t)deviceType) { |
30 | TORCH_CHECK(deviceType == at::kCPU || deviceType == at::kCUDA); |
31 | } |
32 | |
33 | std::string toString() const override { |
34 | return c10::DeviceTypeName((c10::DeviceType)type); |
35 | } |
36 | |
37 | c10::DeviceType c10Type() const { |
38 | return (c10::DeviceType)type; |
39 | } |
40 | }; |
41 | |
42 | class TSBackendImpl : public torch::lazy::BackendImplInterface { |
43 | public: |
44 | TSBackendImpl() { |
45 | // TODO(whc) unify how all our flags are set and parsed as envs |
46 | static bool env_use_cuda = std::getenv("LTC_TS_CUDA" ) != nullptr; |
47 | auto type = |
48 | (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU; |
49 | default_device_type_ = std::make_shared<TSBackendDeviceType>(type); |
50 | } |
51 | |
52 | const IrBuilder* GetIrBuilder() const override { |
53 | static const IrBuilder* builder = new TorchScriptIrBuilder(); |
54 | return builder; |
55 | } |
56 | |
57 | std::string CreateMetricReport() const override { |
58 | return "TSBackendImpl: N/A" ; |
59 | } |
60 | |
61 | std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext( |
62 | const std::string& name, |
63 | torch::lazy::BackendDevice device, |
64 | c10::ArrayRef<const torch::lazy::Node*> post_order, |
65 | torch::lazy::Util::EmissionMap emit_status) const override { |
66 | return std::make_unique<torch::lazy::TSLoweringContext>( |
67 | name, device, post_order, emit_status); |
68 | } |
69 | |
70 | std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext( |
71 | const std::string& name, |
72 | torch::lazy::BackendDevice device) const override { |
73 | return std::make_unique<torch::lazy::TSLoweringContext>(name, device); |
74 | } |
75 | |
76 | std::vector<std::string> GetCompilationDevices( |
77 | const std::string& device, |
78 | c10::ArrayRef<std::string> devices) const override { |
79 | return std::vector<std::string>(devices.begin(), devices.end()); |
80 | } |
81 | |
82 | at::Tensor MakeTensorFromComputationData( |
83 | const torch::lazy::BackendDataPtr data, |
84 | c10::optional<at::ScalarType> logical_scalar_type) const override { |
85 | const auto ts_data = std::static_pointer_cast<TSData>(data); |
86 | return ts_data->data(); |
87 | } |
88 | |
89 | torch::lazy::BackendDataPtr MakeComputationDataFromTensor( |
90 | const at::Tensor& tensor, |
91 | const torch::lazy::Shape& shape, |
92 | const torch::lazy::BackendDevice& device) const override { |
93 | at::TensorOptions options = tensor.options().device( |
94 | default_device_type_->c10Type(), device.ordinal()); |
95 | if (tensor.device().type() == default_device_type_->c10Type() && |
96 | default_device_type_->c10Type() == at::kCUDA) { |
97 | return std::make_shared<TSData>( |
98 | tensor.to(options, /*non_blocking=*/true), shape, device); |
99 | } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) { |
100 | // calling .item() on singleton cpu tensor is fast, and using fill is a |
101 | // safe, async way to copy cpu to cuda for a single value |
102 | auto device_tensor = at::full(tensor.sizes(), tensor.item(), options); |
103 | return std::make_shared<TSData>(device_tensor, shape, device); |
104 | } else { |
105 | return std::make_shared<TSData>( |
106 | tensor.to(options, /*non_blocking=*/false), shape, device); |
107 | } |
108 | } |
109 | |
110 | torch::lazy::BackendDataPtr MakeComputationDataFromScalar( |
111 | const at::Scalar& scalar, |
112 | const torch::lazy::BackendDevice& device) const override { |
113 | return std::make_shared<TSData>(scalar, device); |
114 | } |
115 | |
116 | torch::lazy::BackendDataPtr GetComputationDataFromNode( |
117 | const Node* node) const override { |
118 | auto* device_data_node = DeviceData::Cast(node); |
119 | if (!device_data_node) { |
120 | return nullptr; |
121 | } |
122 | return device_data_node->data(); |
123 | } |
124 | |
125 | std::string GetComputationBackendText( |
126 | const torch::lazy::ComputationPtr computation) const override { |
127 | auto ts_computation = |
128 | static_cast<torch::lazy::TSComputation*>(computation.get()); |
129 | return ts_computation->graph()->toString(); |
130 | } |
131 | |
132 | //////////////computation client interfaces/////////////////////// |
133 | |
134 | public: |
135 | torch::lazy::BackendDataPtr CreateDataPlaceholder( |
136 | const torch::lazy::BackendDevice& device, |
137 | const torch::lazy::Shape& shape) const override; |
138 | |
139 | std::vector<torch::lazy::ComputationPtr> Compile( |
140 | std::vector<torch::lazy::ComputationPtr> instances) const override; |
141 | |
142 | std::vector<torch::lazy::BackendDataPtr> ExecuteComputation( |
143 | torch::lazy::ComputationPtr computation, |
144 | c10::ArrayRef<torch::lazy::BackendDataPtr> arguments, |
145 | const torch::lazy::BackendDevice& device) const override; |
146 | |
147 | std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() |
148 | const override { |
149 | return default_device_type_; |
150 | } |
151 | |
152 | at::DeviceType EagerFallbackDeviceType() const override; |
153 | |
154 | void SetDefaultDeviceType(int8_t type) override { |
155 | default_device_type_ = std::make_shared<TSBackendDeviceType>( |
156 | static_cast<c10::DeviceType>(type)); |
157 | } |
158 | |
159 | int64_t GetDefaultDeviceOrdinal() const override { |
160 | return default_device_ordinal_; |
161 | } |
162 | |
163 | void SetDefaultDeviceOrdinal(int64_t ordinal) override { |
164 | default_device_ordinal_ = ordinal; |
165 | } |
166 | |
167 | std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override; |
168 | |
169 | torch::lazy::BackendDevice GetBackendDevice( |
170 | c10::Device device) const override; |
171 | |
172 | void SetRngSeed(size_t seed) const override { |
173 | LOG(FATAL) << "Not implemented yet." ; |
174 | } |
175 | |
176 | // std::map<std::string, Metric> GetMetrics() const override { return {}; } |
177 | |
178 | // MemoryInfo GetMemoryInfo(const std::string& device) override { |
179 | // LOG(FATAL) << "Not implemented yet."; |
180 | // } |
181 | |
182 | void PrepareToExit() const override; |
183 | |
184 | private: |
185 | std::shared_ptr<TSBackendDeviceType> default_device_type_; |
186 | int64_t default_device_ordinal_{0}; |
187 | }; |
188 | |
189 | torch::lazy::BackendDataPtr TSBackendImpl::CreateDataPlaceholder( |
190 | const torch::lazy::BackendDevice& device, |
191 | const torch::lazy::Shape& shape) const { |
192 | return std::make_shared<TSData>(shape, device); |
193 | } |
194 | |
195 | std::vector<torch::lazy::ComputationPtr> TSBackendImpl::Compile( |
196 | std::vector<torch::lazy::ComputationPtr> instances) const { |
197 | for (const auto& instance : instances) { |
198 | auto ts_computation = |
199 | static_cast<torch::lazy::TSComputation*>(instance.get()); |
200 | if (!ts_computation->in_mark_step) { |
201 | LOG(WARNING) << "Compile outside of mark step" ; |
202 | } |
203 | } |
204 | return instances; |
205 | } |
206 | |
207 | std::vector<torch::lazy::BackendDataPtr> TSBackendImpl::ExecuteComputation( |
208 | torch::lazy::ComputationPtr computation, |
209 | c10::ArrayRef<torch::lazy::BackendDataPtr> arguments, |
210 | const torch::lazy::BackendDevice& device) const { |
211 | auto ts_computation = |
212 | std::dynamic_pointer_cast<torch::lazy::TSComputation>(computation); |
213 | TORCH_CHECK(ts_computation, "Computation isn't TSComputation" ); |
214 | torch::jit::GraphExecutor& graph_executor = ts_computation->graph_executor(); |
215 | std::vector<torch::jit::IValue> stack; |
216 | for (const auto& argument : arguments) { |
217 | const auto ts_data = std::static_pointer_cast<TSData>(argument); |
218 | if (ts_data->scalar.has_value()) { |
219 | stack.emplace_back(ts_data->scalar.value()); |
220 | } else { |
221 | // TODO(whc) should this check be made more general? it's written somewhat |
222 | // oddly |
223 | CHECK( |
224 | static_cast<c10::DeviceType>(default_device_type_->type) != |
225 | at::kCUDA || |
226 | ts_data->data().device().type() == at::kCUDA); |
227 | stack.emplace_back(ts_data->data()); |
228 | } |
229 | } |
230 | graph_executor.run(stack); |
231 | std::vector<torch::lazy::BackendDataPtr> results; |
232 | for (torch::jit::IValue component : stack) { |
233 | at::Tensor result = component.toTensor(); |
234 | at::IntArrayRef result_sizes = result.sizes(); |
235 | torch::lazy::Shape shape( |
236 | result.scalar_type(), |
237 | std::vector<int64_t>(result_sizes.begin(), result_sizes.end())); |
238 | results.push_back(std::make_shared<TSData>(result, shape, device)); |
239 | } |
240 | return results; |
241 | } |
242 | |
243 | std::vector<torch::lazy::BackendDevice> TSBackendImpl::GetBackendDevices() |
244 | const { |
245 | std::vector<torch::lazy::BackendDevice> devices; |
246 | // TODO(whc) figure out how to query available devices from pytorch |
247 | devices.emplace_back(GetBackendDevice(c10::Device(c10::kCPU, 0))); |
248 | devices.emplace_back(GetBackendDevice(c10::Device(c10::kCUDA, 0))); |
249 | return devices; |
250 | } |
251 | |
252 | torch::lazy::BackendDevice TSBackendImpl::GetBackendDevice( |
253 | c10::Device device) const { |
254 | // Note, we ignore the device type specified by the c10::Device since it is |
255 | // expected to be a virtual device (lazy::), but we need to change this when |
256 | // we support lazy as a mode |
257 | return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index()); |
258 | } |
259 | |
// No teardown is required for the TS backend: its data are ordinary ATen
// tensors managed by reference counting, so this override is a no-op.
void TSBackendImpl::PrepareToExit() const {}
261 | |
262 | c10::DeviceType TSBackendImpl::EagerFallbackDeviceType() const { |
263 | // For TS backend, hardware device _is_ eager device |
264 | return (c10::DeviceType)GetDefaultDeviceType()->type; |
265 | } |
266 | |
// Returns the process-wide TSBackendImpl singleton. The instance is
// heap-allocated and never freed — presumably intentional, to avoid
// static-destruction-order issues at process exit.
torch::lazy::BackendImplInterface* GetTSBackendImpl() {
  static TSBackendImpl* ts_backend_impl = new TSBackendImpl();
  return ts_backend_impl;
}
271 | |
// Explicitly initializes the TorchScript lazy backend: registers the native
// function dispatch keys, installs the eager fallback, registers the backend
// implementation, and installs a LazyGraphExecutor. Registration is deferred
// to this call (rather than static init) so that an external backend loaded
// first gets a chance to register itself — see the note in `namespace at`
// above.
void InitTorchScriptBackend() {
  at::RegisterTorchScriptLazyNativeFunctions();
  at::RegisterTorchScriptAutogradLazyNativeFunctions();
  register_ts_ltc_eager_fallback();
  // NOTE(review): the registrar is re-assigned on every call, while the
  // executor below is only created once — this looks intended to be called
  // at most once per process; confirm with callers before relying on
  // re-invocation.
  static std::unique_ptr<BackendRegistrar> s_registrar;
  s_registrar = std::make_unique<BackendRegistrar>(GetTSBackendImpl());

  // Intentionally leaked (static pointer, no delete) so the executor outlives
  // all static destructors.
  static LazyGraphExecutor* executor = new LazyGraphExecutor();
  LazyGraphExecutor::Register(executor);
}
282 | |
283 | } // namespace lazy |
284 | } // namespace torch |
285 | |