#include <torch/csrc/lazy/python/init.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <c10/core/Device.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/trie.h>
#include <torch/csrc/lazy/python/python_util.h>
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#endif // FBCODE_CAFFE2 || OVRSOURCE
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace lazy {

// TODO(whc) backend 'device' related APIs are not very clear; this code could
// be simplified, but that should probably happen together with
// designing/refactoring the overall approach to getting/setting the default
// eager/lazy device types.
torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) {
  if (device_str.empty()) {
    // A default-constructed BackendDevice already picks up the backend's
    // default device type and ordinal.
    return torch::lazy::BackendDevice();
  }
  return torch::lazy::atenDeviceToBackendDevice(c10::Device(device_str));
}

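// Returns the unique id of the lazy tensor backing `tensor`. The tensor is
// expected to already be a lazy tensor; TryGetLtcTensor returns a null
// pointer otherwise, so passing an eager tensor here is a caller error.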
std::ptrdiff_t GetTensorId(const at::Tensor& tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return lazy_tensor->GetUniqueId();
}

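// Collects the IR nodes backing `tensors` (unwrapping functional tensor
// wrappers first) and renders them with the provided converter, e.g. the
// text or dot dumpers registered below.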
std::string GetTensorsDump(
    const std::vector<at::Tensor>& tensors,
    const std::function<std::string(c10::ArrayRef<const torch::lazy::Node*>)>&
        converter) {
  std::vector<const torch::lazy::Node*> nodes;
  std::vector<torch::lazy::Value> values;
  for (auto& tensor : tensors) {
    auto inner = at::functionalization::impl::from_functional_tensor(tensor);
    torch::lazy::LazyTensorPtr lazy_tensor =
        torch::lazy::TryGetLtcTensor(inner);
    values.push_back(lazy_tensor->GetIrValue());
    nodes.push_back(values.back().node.get());
  }
  return converter(nodes);
}

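// Converts a list of at::Tensors to lazy tensors. With want_all=true the
// result is positionally aligned with the input (non-lazy tensors yield null
// entries); with want_all=false non-lazy tensors are silently skipped.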
std::vector<torch::lazy::LazyTensorPtr> GetLtcTensors(
    const std::vector<at::Tensor>& tensors,
    bool want_all) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors;
  lazy_tensors.reserve(tensors.size());
  if (want_all) {
    for (auto& tensor : tensors) {
      lazy_tensors.push_back(torch::lazy::TryGetLtcTensor(tensor));
    }
  } else {
    for (auto& tensor : tensors) {
      auto lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
      if (lazy_tensor) {
        lazy_tensors.push_back(lazy_tensor);
      }
    }
  }
  return lazy_tensors;
}

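// Dumps the backend (lowered) computation for the lazy tensors among
// `tensors`.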
std::string GetTensorsBackendGraph(const std::vector<at::Tensor>& tensors) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation(
      lazy_tensors);
}

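// Flushes pending lazy operations for the given tensors by executing their
// trace via the LazyGraphExecutor. See SyncTensorsGraph for the exact
// semantics of `wait` and `sync_ltc_data`.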
void SyncTensors(
    const std::vector<at::Tensor>& tensors,
    const std::vector<std::string>& devices,
    bool wait,
    bool sync_ltc_data) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph(
      &lazy_tensors, devices, wait, sync_ltc_data);
}

void initLazyBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto lazy = m.def_submodule("_lazy");
  auto lazy_ts_backend = m.def_submodule("_lazy_ts_backend");

  lazy.def(
      "_mark_step",
      // TODO(whc) this API should probably change from vector<string> to
      // vector<c10::Device>, but in a separate PR
      [](const std::string& device_str,
         const std::vector<std::string>& devices,
         bool wait) {
        pybind11::gil_scoped_release no_gil;
        auto backend_device = GetDeviceOrCurrent(device_str);
        torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
            &backend_device, devices, wait);
        torch::lazy::LazyGraphExecutor::Get()->MarkStep(backend_device);
      },
      py::arg("device") = "",
      py::arg("devices"),
      py::arg("wait") = true);
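  // Rough usage sketch from Python (wrapper names outside this file, e.g.
  // torch._lazy.mark_step, are assumptions):
  //   torch._C._lazy._mark_step('', [], wait=True)
  // syncs all live lazy tensors on the default device and starts a new step.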
  lazy.def(
      "_wait_device_ops",
      [](const std::vector<std::string>& devices) {
        pybind11::gil_scoped_release no_gil;
        // TODO: Add support for non-empty devices.
        if (!devices.empty()) {
          LOG(ERROR) << "Non-empty devices are not supported.";
        }
        torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({});
      },
      py::arg("devices"));
  lazy.def("_reset_metrics", []() {
    torch::lazy::MetricsArena::Get()->ResetCounters();
    torch::lazy::MetricsArena::Get()->ResetMetrics();
  });
  lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); });
  lazy.def(
      "_metrics_report", []() { return torch::lazy::CreateMetricReport(); });
  lazy.def("_counter_value", [](const std::string& name) -> py::object {
    torch::lazy::CounterData* data = torch::lazy::GetCounter(name);
    return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
  });
  lazy.def("_get_tensor_id", [](const at::Tensor& tensor) {
    return GetTensorId(tensor);
  });

  lazy.def(
      "_get_tensors_text",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToText(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_dot",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToDot(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_backend",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        return GetTensorsBackendGraph(tensors);
      });
  lazy.def("_get_graph_hash", [](const std::vector<at::Tensor>& tensors) {
    std::vector<LazyTensorPtr> xtensors;
    xtensors.reserve(tensors.size());
    for (auto& tensor : tensors) {
      xtensors.emplace_back(TryGetLtcTensor(tensor));
    }
    auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors);
    std::string bin((const char*)&hash, sizeof(hash));
    return py::bytes(bin);
  });
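  // The bytes returned by _get_graph_hash are the raw hash_t value for the
  // graph rooted at `tensors`; _lazy_ts_backend._run_cached_graph below
  // expects exactly these bytes back as its hash argument.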
  lazy.def(
      "_sync_multi",
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<std::string>& devices,
         bool wait,
         bool sync_ltc_data) {
        pybind11::gil_scoped_release no_gil;
        SyncTensors(tensors, devices, wait, sync_ltc_data);
      },
      py::arg("tensors"),
      py::arg("devices"),
      py::arg("wait") = true,
      py::arg("sync_ltc_data") = true);

  lazy.def("_get_force_fallback", []() {
    return torch::lazy::getLTCForceFallback();
  });
  lazy.def("_set_force_fallback", [](std::string newval) {
    torch::lazy::getLTCForceFallback() = newval;
  });
  lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); });
  lazy.def("_dump_ir_cache", [](std::string filename) {
    TrieCache::Get()->DumpToDotFile(filename);
  });
  lazy.def("_set_reuse_ir", [](bool val) { FLAGS_torch_lazy_reuse_ir = val; });
  lazy.def("_set_symbolic_shape_mode", [](bool val) {
    FLAGS_ltc_enable_symbolic_shapes = val;
  });
  lazy.def("_get_symbolic_shape_mode", []() {
    return FLAGS_ltc_enable_symbolic_shapes;
  });
  lazy.def("_get_default_device_type", []() {
    return getBackend()->GetDefaultDeviceType()->toString();
  });

  lazy_ts_backend.def("_init", []() {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
    torch::lazy::InitTorchScriptBackend();
#else
    TORCH_CHECK(
        false,
        "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds");
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
  });

  /*
   * Return tensor ids and tensors for DeviceData nodes.
   * TODO(shunting) revisit this API for XLA
   */
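  /*
   * The two returned vectors are parallel: tensor_ids[i] identifies the lazy
   * tensor whose DeviceData node produced ivalues[i], which holds either the
   * backing tensor or, for scalar DeviceData, the scalar value itself.
   */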
  lazy_ts_backend.def(
      "_get_tensors_ts_device_data_node",
      [](const std::vector<at::Tensor>& tensors)
          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        std::vector<const Node*> roots;
        for (auto& tensor : tensors) {
          auto xtensor = TryGetLtcTensor(tensor);
          roots.push_back(xtensor->GetIrValue().node.get());
        }
        auto post_order = Util::ComputePostOrder(roots);
        std::vector<int64_t> tensor_ids;
        std::vector<at::IValue> ivalues;

        std::unordered_set<BackendData::Handle> data_handles;
        for (auto nodeptr : post_order) {
          if (nodeptr->op() == *torch::lazy::ltc_device_data) {
            const auto backend_data =
                getBackend()->GetComputationDataFromNode(nodeptr);

            auto infoptr = backend_data->info();
            auto deviceDataInfoPtr =
                (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
            auto* tsDataPtr = (torch::lazy::TSData*)backend_data.get();

            // Dedup DeviceData nodes by handle.
            auto handle = tsDataPtr->GetHandle();
            if (!data_handles.insert(handle).second) {
              continue;
            }
            tensor_ids.push_back(deviceDataInfoPtr->tensor_id);
            /*
             * If the TSData contains a tensor, then the tensor id will
             * uniquely identify the tensor. We use that tensor id to find the
             * tensor in other places, e.g. in the python forward method
             * parameters.
             *
             * If the TSData contains a scalar, the tensor id itself is not
             * important. We reuse the scalar value in future calls.
             */
            if (tsDataPtr->HasValue()) {
              ivalues.emplace_back(tsDataPtr->data());
            } else {
              CHECK(tsDataPtr->scalar.has_value());
              ivalues.emplace_back(tsDataPtr->scalar.value());
            }
          }
        }
        return std::make_pair(tensor_ids, ivalues);
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
        return std::make_pair(
            std::vector<int64_t>(), std::vector<at::IValue>());
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
      });
  // TODO(shunting) revisit this part for XLA
  lazy_ts_backend.def(
      "_run_cached_graph",
      [](const std::string& hash_str,
         const std::vector<at::IValue>& graph_inputs) {
        std::vector<at::Tensor> result;
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        TORCH_CHECK(hash_str.size() == sizeof(hash_t));
        hash_t hash = *(hash_t*)(hash_str.c_str());
        auto cachedComputation =
            LazyGraphExecutor::Get()->GetComputationCache()->Get(hash);
        // TODO: implement a fallback mechanism, or make sure those entries
        // never get kicked out.
        TORCH_CHECK(
            cachedComputation,
            "Failed to get computation by hash. Maybe the entry got kicked out of the LRU cache");
        auto computationPtr =
            (torch::lazy::TSComputation*)cachedComputation->computation.get();

        std::vector<torch::jit::IValue> stack;
        stack.reserve(graph_inputs.size());
        for (const auto& arg : graph_inputs) {
          stack.emplace_back(arg);
        }
        computationPtr->graph_executor().run(stack);
        result.reserve(stack.size());
        for (torch::jit::IValue elem : stack) {
          result.push_back(elem.toTensor());
        }
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        return result;
      });
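  // Rough end-to-end sketch of the caching flow the two bindings above
  // enable (the Python-side call sites are assumptions, not defined here):
  //   hash = torch._C._lazy._get_graph_hash(outputs)
  //   ids, ivalues = torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(outputs)
  //   ...later, rebuild graph_inputs to match `ids` with fresh tensors...
  //   results = torch._C._lazy_ts_backend._run_cached_graph(hash, graph_inputs)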

  // GetPythonFramesFunction() has never worked with torchdeploy/multipy,
  // possibly because GetPythonFrames resolves to the external CPython rather
  // than the embedded CPython. So far this problem has only been observed
  // internally, so we will just block it off there.

#if !(defined(USE_DEPLOY))

  // When libtorch_python is loaded, we register the python frame getter;
  // otherwise, debug util simply omits python frames.
  GetPythonFramesFunction() = GetPythonFrames;

#endif // USE_DEPLOY
}

} // namespace lazy
} // namespace torch