1#pragma once
2
#include <torch/csrc/distributed/c10d/comm.hpp>

#include <utility>

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
9
10namespace c10d {
11
12class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
13 public:
14 // Takes a state and a callable hook. The inputs are Python objects.
15 // The state is passed to the hook in runHook method, and it can be used to
16 // maintain and update any state information during the execution of the hook.
17 // The hook performs user-specified processing and returns a future indicating
18 // asychronous communication of gradients.
19 PythonCommHook(py::object state, py::object hook)
20 : state_(std::move(state)), hook_(std::move(hook)) {}
21
22 ~PythonCommHook() override;
23
24 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
25
26 at::Tensor parseHookResult(const c10::IValue& result) override;
27
28 private:
29 // Only needed for stateful communication.
30 py::object state_;
31 py::object hook_;
32};
33
34} // namespace c10d
35