1#include <torch/csrc/distributed/rpc/py_rref.h>
2
3#include <torch/csrc/autograd/autograd.h>
4#include <torch/csrc/distributed/autograd/autograd.h>
5#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
6#include <torch/csrc/distributed/rpc/python_functions.h>
7#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
8#include <torch/csrc/distributed/rpc/rref_context.h>
9#include <torch/csrc/jit/python/module_python.h>
10#include <torch/csrc/jit/python/pybind_utils.h>
11
12namespace torch {
13namespace distributed {
14namespace rpc {
15
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
17
18namespace {
19
20py::tuple toPyTuple(const RRefForkData& rrefForkData) {
21 // add GIL as it is contructing a py::object
22 pybind11::gil_scoped_acquire ag;
23 return py::make_tuple(
24 rrefForkData.ownerId_,
25 rrefForkData.rrefId_.createdOn_,
26 rrefForkData.rrefId_.localId_,
27 rrefForkData.forkId_.createdOn_,
28 rrefForkData.forkId_.localId_,
29 rrefForkData.parent_,
30 rrefForkData.typeStr_);
31}
32
33RRefForkData fromPyTuple(const py::tuple& pyTuple) {
34 // add GIL as it is accessing a py::object
35 pybind11::gil_scoped_acquire ag;
36 TORCH_INTERNAL_ASSERT(
37 pyTuple.size() == RFD_TUPLE_SIZE,
38 "Pickled RRefForkData must contain ",
39 RFD_TUPLE_SIZE,
40 " numbers.");
41 worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
42 // const reference will extend the lifetime of the temporary variable
43 const RRefId& rrefId = RRefId(
44 pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
45 pyTuple[RREFID_ID_IDX].cast<local_id_t>());
46 const RRefId& forkId = RRefId(
47 pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
48 pyTuple[FORKID_ID_IDX].cast<local_id_t>());
49
50 worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
51 const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();
52
53 return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
54}
55
56TypePtr tryInferTypeWithTypeHint(
57 const py::object& value,
58 const py::object& type_hint) {
59 // If the py::object to be contained by the RRef is a ScriptModule, we enforce
60 // users to specify its ModuleInterface type.
61 if (auto module = jit::as_module(value)) {
62 TORCH_CHECK(
63 !type_hint.is_none(),
64 "The RRef being created contains a ScriptModule, "
65 "must provide its ModuleInterface type hint. ");
66 c10::QualifiedName type_qualified_name = c10::QualifiedName(
67 py::cast<std::string>(py::module::import("torch._jit_internal")
68 .attr("_qualified_name")(type_hint)));
69 TypePtr type_hint_ptr =
70 jit::get_python_cu()->get_interface(type_qualified_name);
71 std::ostringstream subtype_check_msg;
72 TORCH_CHECK(
73 type_hint_ptr != nullptr &&
74 module.value().type()->isSubtypeOfExt(
75 *type_hint_ptr, &subtype_check_msg),
76 module.value().type()->repr_str(),
77 " is not a subtype of the type hint: ",
78 type_qualified_name.qualifiedName(),
79 ", did you pass a valid interface type?\n",
80 subtype_check_msg.str());
81 return type_hint_ptr;
82 } else {
83 TORCH_CHECK(
84 type_hint.is_none(),
85 "type_hint should only be specified when the RRef being created contains a ScriptModule.");
86 }
87
88 // Check if value is an instance of a ScriptClass. If not, skip type inference
89 // because it will try to script the class that value is in instance of, and
90 // this should be avoided.
91 py::bool_ can_compile = py::module::import("torch._jit_internal")
92 .attr("can_compile_class")(value.get_type());
93
94 if (py::cast<bool>(can_compile)) {
95 py::object existing_ty = py::module::import("torch.jit._state")
96 .attr("_get_script_class")(value.get_type());
97
98 if (existing_ty.is_none()) {
99 return PyObjectType::get();
100 }
101 }
102
103 // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
104 // excluding ScriptModule.
105 jit::InferredType type_inferred = jit::tryToInferType(value);
106 if (type_inferred.success()) {
107 // If we could infer the type from the pyobject, we create
108 // the RRef with the IValue of that type.
109 return type_inferred.type();
110 }
111
112 // Otherwise it's a pure pyobject, create the RRef
113 // that holds an IValue of an pyobject.
114 return PyObjectType::get();
115}
116
117} // namespace
118
119/////////////////////////// PyRRef //////////////////////////////////
120
121PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
122 : rref_(std::move(rref)), profilingFuture_(c10::nullopt) {
123 TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
124 C10_LOG_API_USAGE_ONCE("torch.distributed.rref");
125}
126
127PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
128 : PyRRef([&value, &type_hint]() mutable {
129 TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
130 auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
131 // jit::toIValue takes a py::handle as the first argument, and it calls
132 // py::handle.cast<py::object>() to incref of provided value. The
133 // returned ivalue will keep the reference alive.
134 // NB: the first argument const py::object& value must be kept alive
135 // until the following jit::toIValue returns (i.e., incref done). That's
136 // why this ctor can only be called while holding GIL.
137 IValue ivalue = jit::toIValue(value, elem_type);
138 rref->setValue(std::move(ivalue));
139 return rref;
140 }()) {}
141
142PyRRef::~PyRRef() {
143 if (type_.has_value()) {
144 (*type_).dec_ref();
145 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
146 // decref on the PyObject again.
147 // See Note [Destructing py::object] in python_ivalue.h
148 (*type_).ptr() = nullptr;
149 }
150}
151
152c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
153 // Marking hasValue to false, as this Future is only used for signaling
154 // profiler to update profiling result and the profiler does not retrieve
155 // any value from it.
156 return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */);
157}
158
159c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
160 TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
161 return *profilingFuture_;
162}
163
164void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) {
165 profilingFuture_ = std::move(profilingFuture);
166}
167
168bool PyRRef::isOwner() const {
169 return rref_->isOwner();
170}
171
172bool PyRRef::confirmedByOwner() const {
173 return rref_->confirmedByOwner();
174}
175
176WorkerInfo PyRRef::owner() const {
177 return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
178}
179
180std::string PyRRef::ownerName() const {
181 return rref_->ownerName();
182}
183
184py::object PyRRef::toHere(const float timeoutSeconds) const {
185 C10_LOG_API_USAGE_ONCE("torch.distributed.rref.to_here");
186 if (rref_->isOwner()) {
187 return localValue();
188 } else {
189 // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
190 // a python object
191 IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
192 timeoutSeconds);
193
194 if (rref_->isPyObj()) {
195 // python_rpc_handler deserialization will acquires GIL.
196 auto rfr_values = value.toTupleRef().elements().vec();
197 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
198 auto ret = pythonRpcHandler.deserialize(
199 SerializedPyObj::fromIValues(std::move(rfr_values)));
200 pythonRpcHandler.handleException(ret);
201 return ret;
202 } else {
203 // acquiring GIL as torch::jit::toPyObject creates new py::object
204 // without grabbing the GIL.
205 pybind11::gil_scoped_acquire ag;
206 return torch::jit::toPyObject(std::move(value));
207 }
208 }
209}
210
211py::object PyRRef::localValue() const {
212 TORCH_CHECK(
213 rref_->isOwner(),
214 "For ",
215 *rref_,
216 ", can't call localValue() on user ",
217 RRefContext::getInstance().agent()->getWorkerInfo(),
218 ". Call it on owner ",
219 owner());
220
221 py::object res;
222 auto value =
223 c10::static_intrusive_pointer_cast<const OwnerRRef>(rref_)->getValue();
224 auto& rpcHandler = PythonRpcHandler::getInstance();
225 {
226 // acquiring GIL as torch::jit::toPyObject creates new py::object without
227 // grabbing the GIL.
228 pybind11::gil_scoped_acquire ag;
229 res = torch::jit::toPyObject(std::move(value));
230 rpcHandler.handleExceptionGILHeld(res);
231 }
232 return res;
233}
234
235std::string PyRRef::str() const {
236 if (rref_->isOwner()) {
237 return c10::str("OwnerRRef(", rref_->rrefId(), ")");
238 } else {
239 return c10::str(
240 "UserRRef(RRefId = ",
241 rref_->rrefId(),
242 ", ForkId = ",
243 c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
244 ")");
245 }
246}
247
248py::object PyRRef::createRRefProxy(
249 const RRefProxyType& type,
250 float timeoutSeconds) const {
251 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
252 pybind11::gil_scoped_acquire ag;
253 auto& functions = pythonRpcHandler.getRRefProxyFunctions();
254 auto& ctor = functions.rrefProxyCtor_;
255 switch (type) {
256 case RRefProxyType::RPC_SYNC: {
257 return ctor(*this, functions.rpcSync_, timeoutSeconds);
258 }
259 case RRefProxyType::RPC_ASYNC: {
260 return ctor(*this, functions.rpcAsync_, timeoutSeconds);
261 }
262 case RRefProxyType::REMOTE: {
263 return ctor(*this, functions.remote_, timeoutSeconds);
264 }
265 default: {
266 TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
267 }
268 }
269}
270
271py::object PyRRef::getRRefType(float timeout, bool blocking) {
272 // GIL is not released when calling this function.
273 if (!type_.has_value()) {
274 pybind11::gil_scoped_release release;
275 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
276 auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
277 pybind11::gil_scoped_acquire acquire;
278 type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking)
279 : typeFuncs.onUser_(*this, timeout, blocking);
280 }
281 // Returns py::object that can be Python type or future.
282 return *type_;
283}
284
285py::tuple PyRRef::pickle() const {
286 auto& ctx = RRefContext::getInstance();
287 auto rrefForkData = ctx.prepareChildFork(rref_);
288 return toPyTuple(rrefForkData);
289}
290
291PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
292 auto& ctx = RRefContext::getInstance();
293 auto rrefForkData = fromPyTuple(pyTuple);
294 TypePtr rrefType =
295 PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
296 c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
297 ctx.notifyOwnerAndParentOfFork(
298 rrefForkData.forkId_, rrefForkData.parent_, rref);
299 return PyRRef(std::move(rref));
300}
301
302c10::IValue PyRRef::toIValue() const {
303 // cast to RRefInterface to hold it into IValue
304 auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
305 return IValue(rrefPtr);
306}
307
308void PyRRef::backward(int64_t autogradContextId, bool retainGraph) {
309 backward(autogradContextId, retainGraph, rref_);
310}
311
312void PyRRef::backwardOwnerRRef(
313 int64_t autogradContextId,
314 bool retainGraph,
315 IValue value) {
316 // If we have a PyObj, retrieve the underlying tensor.
317 if (value.isPyObject()) {
318 py::gil_scoped_acquire gil;
319 py::object obj = torch::jit::toPyObject(value);
320 try {
321 value = torch::jit::toIValue(obj, c10::TensorType::get());
322 } catch (py::cast_error& e) {
323 TORCH_CHECK(false, "RRef should contain a tensor for .backward()");
324 }
325 }
326
327 TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()");
328 auto root = value.toTensor();
329
330 if (autogradContextId == -1) {
331 torch::autograd::backward({root});
332 } else {
333 torch::distributed::autograd::backward(
334 autogradContextId, {root}, retainGraph);
335 }
336}
337
338void PyRRef::backward(
339 int64_t autogradContextId,
340 bool retainGraph,
341 const c10::intrusive_ptr<RRef>& rref) {
342 if (rref->isOwner()) {
343 backwardOwnerRRef(
344 autogradContextId,
345 retainGraph,
346 c10::static_intrusive_pointer_cast<const OwnerRRef>(rref)->getValue());
347 } else {
348 TORCH_CHECK(
349 autogradContextId != -1,
350 "User RRefs require 'dist_autograd_ctx_id' to be specified");
351
352 autograd::RRefBackwardReq rrefBackwardReq(
353 rref->rrefId(), autogradContextId, retainGraph);
354
355 // Invoke distributed backward remotely.
356 auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
357 rpcAgent
358 ->send(
359 rpcAgent->getWorkerInfo(rref->owner()),
360 std::move(rrefBackwardReq).toMessage())
361 ->waitAndThrow();
362 }
363}
364
365} // namespace rpc
366} // namespace distributed
367} // namespace torch
368