1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
4 | #include <torch/csrc/python_headers.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | |
7 | namespace torch { |
8 | namespace distributed { |
9 | namespace rpc { |
10 | |
// Mode selector for an RRef proxy; this is the type of the `mode` argument of
// PyRRef::createRRefProxy.
// NOTE(review): the enumerators presumably map to the rpc_sync / rpc_async /
// remote invocation styles -- confirm against the implementation.
enum RRefProxyType {
  RPC_SYNC = 0,
  RPC_ASYNC = 1,
  REMOTE = 2,
};
12 | |
13 | // Python wrapper of an RRef shared_ptr that supports Python |
14 | // pickle and unpickle. |
15 | class PYBIND11_EXPORT PyRRef { |
16 | public: |
17 | // The first ctor can only be called while holding GIL. See its implementation |
18 | // for more explanations. |
19 | explicit PyRRef(const py::object& value, const py::object& type_hint); |
20 | explicit PyRRef(c10::intrusive_ptr<RRef> rref); |
21 | ~PyRRef(); |
22 | |
23 | bool isOwner() const; |
24 | bool confirmedByOwner() const; |
25 | WorkerInfo owner() const; |
26 | std::string ownerName() const; |
27 | py::object toHere( |
28 | const float timeoutSeconds = |
29 | torch::distributed::rpc::kUnsetRpcTimeout) const; |
30 | py::object localValue() const; |
31 | std::string str() const; |
32 | py::tuple pickle() const; |
33 | static PyRRef unpickle(const py::tuple& t); |
34 | c10::IValue toIValue() const; |
35 | // Future that is associated with the creation of this RRef on the remote end. |
36 | // This is only used to get the future corresponding to the rref for profiling |
37 | // use cases. |
38 | c10::intrusive_ptr<JitFuture> getFuture() const; |
39 | // Keeps track of the future responsible for profiling owner creation |
40 | // acknowledgement |
41 | c10::intrusive_ptr<JitFuture> getProfilingFuture() const; |
42 | // Sets the future responsible for profiling owner creation acknowledgement. |
43 | // This future is set from python to be a future that returns when profiling |
44 | // callbacks have been run. |
45 | void setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture); |
46 | |
47 | // create a proxy on this RRef, which can be used to launch RPC on the owner |
48 | // of this RRef to run functions on the object referenced by this RRef. |
49 | py::object createRRefProxy( |
50 | const RRefProxyType& mode, |
51 | float timeoutSeconds = rpc::kUnsetRpcTimeout) const; |
52 | |
53 | // get the type of the data object referenced by this RRef. Timeout argument |
54 | // is only used in the first invocation of this function as an argument to the |
55 | // RPC to the owner node of the RRef. |
56 | py::object getRRefType( |
57 | float timeout = rpc::kUnsetRpcTimeout, |
58 | bool blocking = true); |
59 | |
60 | // Run the backward pass with the RRef as the root. |
61 | void backward(int64_t autogradContextId, bool retainGraph); |
62 | |
63 | // Helper static function to run backward on a given rref. |
64 | static void backward( |
65 | int64_t autogradContextId, |
66 | bool retainGraph, |
67 | const c10::intrusive_ptr<RRef>& rref); |
68 | |
69 | // Specialization of backward if the rref is an OwnerRRef. |
70 | static void backwardOwnerRRef( |
71 | int64_t autogradContextId, |
72 | bool retainGraph, |
73 | IValue value); |
74 | |
75 | private: |
76 | c10::intrusive_ptr<RRef> rref_; |
77 | c10::optional<c10::intrusive_ptr<JitFuture>> profilingFuture_; |
78 | c10::optional<py::object> type_; |
79 | }; |
80 | |
81 | } // namespace rpc |
82 | } // namespace distributed |
83 | } // namespace torch |
84 | |