1#include <ATen/record_function.h>
2#include <torch/csrc/distributed/autograd/autograd.h>
3
4namespace torch {
5namespace distributed {
6namespace autograd {
7
8constexpr auto kDistAutogradBackwardProfilingKey =
9 "torch::distributed::autograd::backward";
10
11void backward(
12 int64_t context_id,
13 const variable_list& roots,
14 bool retain_graph) {
15 C10_LOG_API_USAGE_ONCE("torch.distributed.autograd.backward");
16 RECORD_FUNCTION(
17 kDistAutogradBackwardProfilingKey, std::vector<c10::IValue>());
18 try {
19 DistEngine::getInstance().execute(context_id, roots, retain_graph);
20 } catch (std::exception& e) {
21 // FIXME: crashes if exception type is not RuntimeError
22 TORCH_CHECK(false, e.what());
23 }
24}
25
26} // namespace autograd
27} // namespace distributed
28} // namespace torch
29