1#pragma once
2
3#include <torch/csrc/distributed/rpc/rref_impl.h>
4#include <torch/csrc/python_headers.h>
5#include <torch/csrc/utils/pybind.h>
6
7namespace torch {
8namespace distributed {
9namespace rpc {
10
// Mode argument for PyRRef::createRRefProxy(): selects how the proxy created
// on an RRef launches calls on the RRef's owner. The enumerator names suggest
// they mirror the rpc_sync / rpc_async / remote RPC APIs respectively.
// NOTE(review): confirm that mapping against the createRRefProxy definition.
//
// The values below are spelled out explicitly but are exactly the 0, 1, 2 the
// compiler would assign implicitly; this stays an unscoped enum so existing
// unqualified uses of the enumerators keep compiling.
enum RRefProxyType {
  RPC_SYNC = 0,
  RPC_ASYNC = 1,
  REMOTE = 2,
};
12
13// Python wrapper of an RRef shared_ptr that supports Python
14// pickle and unpickle.
15class PYBIND11_EXPORT PyRRef {
16 public:
17 // The first ctor can only be called while holding GIL. See its implementation
18 // for more explanations.
19 explicit PyRRef(const py::object& value, const py::object& type_hint);
20 explicit PyRRef(c10::intrusive_ptr<RRef> rref);
21 ~PyRRef();
22
23 bool isOwner() const;
24 bool confirmedByOwner() const;
25 WorkerInfo owner() const;
26 std::string ownerName() const;
27 py::object toHere(
28 const float timeoutSeconds =
29 torch::distributed::rpc::kUnsetRpcTimeout) const;
30 py::object localValue() const;
31 std::string str() const;
32 py::tuple pickle() const;
33 static PyRRef unpickle(const py::tuple& t);
34 c10::IValue toIValue() const;
35 // Future that is associated with the creation of this RRef on the remote end.
36 // This is only used to get the future corresponding to the rref for profiling
37 // use cases.
38 c10::intrusive_ptr<JitFuture> getFuture() const;
39 // Keeps track of the future responsible for profiling owner creation
40 // acknowledgement
41 c10::intrusive_ptr<JitFuture> getProfilingFuture() const;
42 // Sets the future responsible for profiling owner creation acknowledgement.
43 // This future is set from python to be a future that returns when profiling
44 // callbacks have been run.
45 void setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture);
46
47 // create a proxy on this RRef, which can be used to launch RPC on the owner
48 // of this RRef to run functions on the object referenced by this RRef.
49 py::object createRRefProxy(
50 const RRefProxyType& mode,
51 float timeoutSeconds = rpc::kUnsetRpcTimeout) const;
52
53 // get the type of the data object referenced by this RRef. Timeout argument
54 // is only used in the first invocation of this function as an argument to the
55 // RPC to the owner node of the RRef.
56 py::object getRRefType(
57 float timeout = rpc::kUnsetRpcTimeout,
58 bool blocking = true);
59
60 // Run the backward pass with the RRef as the root.
61 void backward(int64_t autogradContextId, bool retainGraph);
62
63 // Helper static function to run backward on a given rref.
64 static void backward(
65 int64_t autogradContextId,
66 bool retainGraph,
67 const c10::intrusive_ptr<RRef>& rref);
68
69 // Specialization of backward if the rref is an OwnerRRef.
70 static void backwardOwnerRRef(
71 int64_t autogradContextId,
72 bool retainGraph,
73 IValue value);
74
75 private:
76 c10::intrusive_ptr<RRef> rref_;
77 c10::optional<c10::intrusive_ptr<JitFuture>> profilingFuture_;
78 c10::optional<py::object> type_;
79};
80
81} // namespace rpc
82} // namespace distributed
83} // namespace torch
84