1 | #include <torch/csrc/distributed/c10d/python_comm_hook.h> |
---|---|
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <torch/csrc/distributed/c10d/reducer.hpp> |
5 | #include <torch/csrc/jit/python/pybind_utils.h> |
6 | #include <torch/csrc/utils/tensor_flatten.h> |
7 | |
8 | namespace c10d { |
9 | |
10 | PythonCommHook::~PythonCommHook() { |
11 | py::gil_scoped_acquire ag; |
12 | state_.dec_ref(); |
13 | hook_.dec_ref(); |
14 | // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor |
15 | // to decref on the PyObject again. |
16 | // See Note [Destructing py::object] in python_ivalue.h |
17 | state_.ptr() = nullptr; |
18 | hook_.ptr() = nullptr; |
19 | } |
20 | |
21 | c10::intrusive_ptr<c10::ivalue::Future> PythonCommHook::runHook( |
22 | GradBucket& bucket) { |
23 | py::gil_scoped_acquire acquire; |
24 | |
25 | py::object py_fut = hook_(state_, bucket); |
26 | |
27 | try { |
28 | return py_fut.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>()->fut; |
29 | } catch (const py::cast_error& e) { |
30 | auto type = py_fut.get_type(); |
31 | auto errMsg = c10::str( |
32 | e.what(), |
33 | ". DDP communication hook's callback must return a " |
34 | "torch.futures.Future object, but got ", |
35 | type.attr("__module__").cast<std::string>(), |
36 | ".", |
37 | type.attr("__qualname__").cast<std::string>()); |
38 | TORCH_CHECK(false, errMsg); |
39 | } |
40 | } |
41 | |
42 | at::Tensor PythonCommHook::parseHookResult(const c10::IValue& result) { |
43 | TORCH_INTERNAL_ASSERT( |
44 | result.isPyObject(), "expected the hook result is a PyObject"); |
45 | |
46 | py::gil_scoped_acquire ag; |
47 | py::object obj = torch::jit::toPyObject(result); |
48 | auto value = torch::jit::toIValue(obj, c10::TensorType::get()); |
49 | return value.toTensor(); |
50 | } |
51 | |
52 | } // namespace c10d |
53 |