1 | #pragma once |
---|---|
2 | |
#include <torch/csrc/distributed/c10d/comm.hpp>

#include <utility>

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
9 | |
10 | namespace c10d { |
11 | |
12 | class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { |
13 | public: |
14 | // Takes a state and a callable hook. The inputs are Python objects. |
15 | // The state is passed to the hook in runHook method, and it can be used to |
16 | // maintain and update any state information during the execution of the hook. |
17 | // The hook performs user-specified processing and returns a future indicating |
18 | // asychronous communication of gradients. |
19 | PythonCommHook(py::object state, py::object hook) |
20 | : state_(std::move(state)), hook_(std::move(hook)) {} |
21 | |
22 | ~PythonCommHook() override; |
23 | |
24 | c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; |
25 | |
26 | at::Tensor parseHookResult(const c10::IValue& result) override; |
27 | |
28 | private: |
29 | // Only needed for stateful communication. |
30 | py::object state_; |
31 | py::object hook_; |
32 | }; |
33 | |
34 | } // namespace c10d |
35 |