1 | #include <torch/csrc/lazy/python/init.h> |
2 | |
3 | #include <ATen/FunctionalTensorWrapper.h> |
4 | #include <c10/core/Device.h> |
5 | #include <torch/csrc/jit/python/pybind.h> |
6 | #include <torch/csrc/lazy/backend/backend_device.h> |
7 | #include <torch/csrc/lazy/backend/backend_interface.h> |
8 | #include <torch/csrc/lazy/core/config.h> |
9 | #include <torch/csrc/lazy/core/debug_util.h> |
10 | #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> |
11 | #include <torch/csrc/lazy/core/ir_dump_util.h> |
12 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
13 | #include <torch/csrc/lazy/core/metrics.h> |
14 | #include <torch/csrc/lazy/core/trie.h> |
15 | #include <torch/csrc/lazy/python/python_util.h> |
16 | #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) |
17 | #include <torch/csrc/lazy/ts_backend/ts_backend_impl.h> |
18 | #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h> |
19 | #endif // FBCODE_CAFFE2 || OVRSOURCE |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | namespace torch { |
24 | namespace lazy { |
25 | |
// TODO(whc) backend 'device' related APIs are not very clear, this code could
// be simplified but it should probably be done together with
// designing/refactoring the overall approach to get/set of default eager/lazy
// device types
//
// Maps a (possibly empty) device string to a BackendDevice: an empty string
// yields a default-constructed BackendDevice, anything else is parsed as a
// c10::Device and converted.
torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) {
  if (device_str.empty()) {
    // NOTE(review): the return value of this call is discarded. Presumably it
    // is kept for a backend-initialization side effect before constructing the
    // default BackendDevice below -- confirm, otherwise this line is dead code.
    getBackend()->GetDefaultDeviceType();
    return torch::lazy::BackendDevice();
  }
  return torch::lazy::atenDeviceToBackendDevice(c10::Device(device_str));
}
37 | |
38 | std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { |
39 | torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor); |
40 | return lazy_tensor->GetUniqueId(); |
41 | } |
42 | |
43 | std::string GetTensorsDump( |
44 | const std::vector<at::Tensor>& tensors, |
45 | const std::function<std::string(c10::ArrayRef<const torch::lazy::Node*>)>& |
46 | coverter) { |
47 | std::vector<const torch::lazy::Node*> nodes; |
48 | std::vector<torch::lazy::Value> values; |
49 | for (auto& tensor : tensors) { |
50 | auto inner = at::functionalization::impl::from_functional_tensor(tensor); |
51 | torch::lazy::LazyTensorPtr lazy_tensor = |
52 | torch::lazy::TryGetLtcTensor(inner); |
53 | values.push_back(lazy_tensor->GetIrValue()); |
54 | nodes.push_back(values.back().node.get()); |
55 | } |
56 | return coverter(nodes); |
57 | } |
58 | |
59 | std::vector<torch::lazy::LazyTensorPtr> GetLtcTensors( |
60 | const std::vector<at::Tensor>& tensors, |
61 | bool want_all) { |
62 | std::vector<torch::lazy::LazyTensorPtr> lazy_tensors; |
63 | lazy_tensors.reserve(tensors.size()); |
64 | if (want_all) { |
65 | for (auto& tensor : tensors) { |
66 | lazy_tensors.push_back(torch::lazy::TryGetLtcTensor(tensor)); |
67 | } |
68 | } else { |
69 | for (auto& tensor : tensors) { |
70 | auto lazy_tensor = torch::lazy::TryGetLtcTensor(tensor); |
71 | if (lazy_tensor) { |
72 | lazy_tensors.push_back(lazy_tensor); |
73 | } |
74 | } |
75 | } |
76 | return lazy_tensors; |
77 | } |
78 | |
79 | std::string GetTensorsBackendGraph(const std::vector<at::Tensor>& tensors) { |
80 | std::vector<torch::lazy::LazyTensorPtr> lazy_tensors = |
81 | GetLtcTensors(tensors, /*want_all=*/false); |
82 | return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation( |
83 | lazy_tensors); |
84 | } |
85 | |
86 | void SyncTensors( |
87 | const std::vector<at::Tensor>& tensors, |
88 | const std::vector<std::string>& devices, |
89 | bool wait, |
90 | bool sync_ltc_data) { |
91 | std::vector<torch::lazy::LazyTensorPtr> lazy_tensors = |
92 | GetLtcTensors(tensors, /*want_all=*/false); |
93 | torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph( |
94 | &lazy_tensors, devices, wait, sync_ltc_data); |
95 | } |
96 | |
// Registers the python bindings for lazy tensors: the `_lazy` submodule
// (step/sync control, metrics, IR dumping, debug toggles) and the
// `_lazy_ts_backend` submodule (TorchScript-backend specific helpers).
void initLazyBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto lazy = m.def_submodule("_lazy" );
  auto lazy_ts_backend = m.def_submodule("_lazy_ts_backend" );

  lazy.def(
      "_mark_step" ,
      // TODO(whc) this API should probably change from vector<string> to
      // vector<c10::device> but in a separate PR
      [](const std::string& device_str,
         const std::vector<std::string>& devices,
         bool wait) {
        // Sync all live lazy tensors on the resolved device, then start a
        // new step/trace. The GIL is released since this may block on
        // device work.
        pybind11::gil_scoped_release no_gil;
        auto backend_device = GetDeviceOrCurrent(device_str);
        torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
            &backend_device, devices, wait);
        torch::lazy::LazyGraphExecutor::Get()->MarkStep(backend_device);
      },
      py::arg("device" ) = "" ,
      py::arg("devices" ),
      py::arg("wait" ) = true);
  lazy.def(
      "_wait_device_ops" ,
      [](const std::vector<std::string>& devices) {
        // Blocks until all in-flight device operations complete.
        pybind11::gil_scoped_release no_gil;
        // TODO: Add support of non-empty devices.
        if (!devices.empty()) {
          LOG(ERROR) << "Non-empty devices are not supported." ;
        }
        torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({});
      },
      py::arg("devices" ));
  // Metrics / counters introspection helpers.
  lazy.def("_reset_metrics" , []() {
    torch::lazy::MetricsArena::Get()->ResetCounters();
    torch::lazy::MetricsArena::Get()->ResetMetrics();
  });
  lazy.def("_counter_names" , []() { return torch::lazy::GetCounterNames(); });
  lazy.def(
      "_metrics_report" , []() { return torch::lazy::CreateMetricReport(); });
  lazy.def("_counter_value" , [](const std::string& name) -> py::object {
    // Returns None (rather than raising) for an unknown counter name.
    torch::lazy::CounterData* data = torch::lazy::GetCounter(name);
    return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
  });
  lazy.def("_get_tensor_id" , [](const at::Tensor& tensor) {
    return GetTensorId(tensor);
  });

  // IR / computation dumping helpers: text and dot renderings of the lazy
  // IR, plus the lowered backend computation.
  lazy.def(
      "_get_tensors_text" ,
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto coverter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToText(nodes);
        };
        return GetTensorsDump(tensors, coverter);
      });
  lazy.def(
      "_get_tensors_dot" ,
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto coverter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToDot(nodes);
        };
        return GetTensorsDump(tensors, coverter);
      });
  lazy.def(
      "_get_tensors_backend" ,
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        return GetTensorsBackendGraph(tensors);
      });
  lazy.def("_get_graph_hash" , [](const std::vector<at::Tensor>& tensors) {
    std::vector<LazyTensorPtr> xtensors;
    xtensors.reserve(tensors.size());
    for (auto& tensor : tensors) {
      xtensors.emplace_back(TryGetLtcTensor(tensor));
    }
    auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors);
    // Expose the raw hash bytes to python; callers treat them as opaque.
    std::string bin((const char*)&hash, sizeof(hash));
    return py::bytes(bin);
  });
  lazy.def(
      "_sync_multi" ,
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<std::string>& devices,
         bool wait,
         bool sync_ltc_data) {
        // GIL released: syncing may block on device execution.
        pybind11::gil_scoped_release no_gil;
        SyncTensors(tensors, devices, wait, sync_ltc_data);
      },
      py::arg("tensors" ),
      py::arg("devices" ),
      py::arg("wait" ) = true,
      py::arg("sync_ltc_data" ) = true);

  // Debug / tuning toggles backed by process-wide flags and caches.
  lazy.def("_get_force_fallback" , []() {
    return torch::lazy::getLTCForceFallback();
  });
  lazy.def("_set_force_fallback" , [](std::string newval) {
    torch::lazy::getLTCForceFallback() = newval;
  });
  lazy.def("_clear_ir_cache" , []() { TrieCache::Get()->Clear(); });
  lazy.def("_dump_ir_cache" , [](std::string filename) {
    TrieCache::Get()->DumpToDotFile(filename);
  });
  lazy.def("_set_reuse_ir" , [](bool val) { FLAGS_torch_lazy_reuse_ir = val; });
  lazy.def("_set_symbolic_shape_mode" , [](bool val) {
    FLAGS_ltc_enable_symbolic_shapes = val;
  });
  lazy.def("_get_symbolic_shape_mode" , []() {
    return FLAGS_ltc_enable_symbolic_shapes;
  });
  lazy.def("_get_default_device_type" , []() {
    return getBackend()->GetDefaultDeviceType()->toString();
  });

  // TorchScript backend initialization; compiled out in FBCODE/OVRSOURCE
  // builds where the TS backend sources are not available.
  lazy_ts_backend.def("_init" , []() {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
    torch::lazy::InitTorchScriptBackend();
#else
    TORCH_CHECK(false, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds" );
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
  });

  /*
   * Return tensor ids and tensors for DeviceData nodes.
   * TODO(shunting) revisit this API for XLA
   */
  lazy_ts_backend.def(
      "_get_tensors_ts_device_data_node" ,
      [](const std::vector<at::Tensor>& tensors)
          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        // Collect the IR roots for the given tensors, then walk the graph in
        // post order looking for DeviceData nodes.
        std::vector<const Node*> roots;
        for (auto& tensor : tensors) {
          auto xtensor = TryGetLtcTensor(tensor);
          roots.push_back(xtensor->GetIrValue().node.get());
        }
        auto post_order = Util::ComputePostOrder(roots);
        std::vector<int64_t> tensor_ids;
        std::vector<at::IValue> ivalues;

        std::unordered_set<BackendData::Handle> data_handles_;
        for (auto nodeptr : post_order) {
          if (nodeptr->op() == *torch::lazy::ltc_device_data) {
            const auto backend_data =
                getBackend()->GetComputationDataFromNode(nodeptr);

            auto infoptr = backend_data->info();
            auto deviceDataInfoPtr =
                (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
            auto* tsDataPtr = (torch::lazy::TSData*)backend_data.get();

            // dedup DeviceData by handle
            auto handle = tsDataPtr->GetHandle();
            if (!data_handles_.insert(handle).second) {
              continue;
            }
            tensor_ids.push_back(deviceDataInfoPtr->tensor_id);
            /*
             * If the TSData contains a tensor, then the tensor id will uniquely
             * identify the tensor. We use that tensor id to find the tensor in
             * other places: e.g. in the python forward method parameters.
             *
             * If the TSData contains a scalar, the tensor id itself is not
             * important. We reuse the scalar value in future calls.
             */
            if (tsDataPtr->HasValue()) {
              ivalues.emplace_back(tsDataPtr->data());
            } else {
              CHECK(tsDataPtr->scalar.has_value());
              ivalues.emplace_back(tsDataPtr->scalar.value());
            }
          }
        }
        return std::make_pair(tensor_ids, ivalues);
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds" );
        return std::make_pair(
            std::vector<int64_t>(), std::vector<at::IValue>());
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
      });
  // TODO(shunting) revisit this part for XLA
  lazy_ts_backend.def(
      "_run_cached_graph" ,
      [](const std::string& hash_str,
         const std::vector<at::IValue>& graph_inputs) {
        // Looks up a previously-compiled computation by its hash bytes (as
        // produced by _get_graph_hash) and runs it with the given inputs,
        // returning the resulting tensors.
        std::vector<at::Tensor> result;
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        TORCH_CHECK(hash_str.size() == sizeof(hash_t));
        hash_t hash = *(hash_t*)(hash_str.c_str());
        auto cachedComputation =
            LazyGraphExecutor::Get()->GetComputationCache()->Get(hash);
        TORCH_CHECK(
            cachedComputation,
            "Failed to get computation by hash. Maybe the entry get kicked out of the LRU cache" ); // TODO implement a fallback mechanism, or make sure those entries never get kicked out
        auto computationPtr =
            (torch::lazy::TSComputation*)cachedComputation->computation.get();

        // Run the TS graph executor on the input stack; outputs are left on
        // the stack in place of the inputs.
        std::vector<torch::jit::IValue> stack;
        stack.reserve(graph_inputs.size());
        for (const auto& arg : graph_inputs) {
          stack.emplace_back(arg);
        }
        computationPtr->graph_executor().run(stack);
        result.reserve(stack.size());
        for (torch::jit::IValue elem : stack) {
          result.push_back(elem.toTensor());
        }
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds" );
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        return result;
      });

  // GetPythonFramesFunction() has not ever worked with torchdeploy/multipy
  // possibly because GetPythonFrames resolves to external cpython rather
  // than embedded cpython. So far this problem has only been observed
  // internally, so we will just block it off there.

#if !(defined(USE_DEPLOY))

  // When libtorch_python is loaded, we register the python frame getter
  // otherwise, debug util simply omits python frames
  GetPythonFramesFunction() = GetPythonFrames;

#endif // USE_DEPLOY
}
324 | |
325 | } // namespace lazy |
326 | } // namespace torch |
327 | |