1 | #include <ATen/record_function.h> |
---|---|
2 | #include <torch/csrc/distributed/autograd/autograd.h> |
3 | |
4 | namespace torch { |
5 | namespace distributed { |
6 | namespace autograd { |
7 | |
8 | constexpr auto kDistAutogradBackwardProfilingKey = |
9 | "torch::distributed::autograd::backward"; |
10 | |
11 | void backward( |
12 | int64_t context_id, |
13 | const variable_list& roots, |
14 | bool retain_graph) { |
15 | C10_LOG_API_USAGE_ONCE("torch.distributed.autograd.backward"); |
16 | RECORD_FUNCTION( |
17 | kDistAutogradBackwardProfilingKey, std::vector<c10::IValue>()); |
18 | try { |
19 | DistEngine::getInstance().execute(context_id, roots, retain_graph); |
20 | } catch (std::exception& e) { |
21 | // FIXME: crashes if exception type is not RuntimeError |
22 | TORCH_CHECK(false, e.what()); |
23 | } |
24 | } |
25 | |
26 | } // namespace autograd |
27 | } // namespace distributed |
28 | } // namespace torch |
29 |