1 | #include <torch/csrc/distributed/rpc/py_rref.h> |
2 | |
3 | #include <torch/csrc/autograd/autograd.h> |
4 | #include <torch/csrc/distributed/autograd/autograd.h> |
5 | #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h> |
6 | #include <torch/csrc/distributed/rpc/python_functions.h> |
7 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
8 | #include <torch/csrc/distributed/rpc/rref_context.h> |
9 | #include <torch/csrc/jit/python/module_python.h> |
10 | #include <torch/csrc/jit/python/pybind_utils.h> |
11 | |
12 | namespace torch { |
13 | namespace distributed { |
14 | namespace rpc { |
15 | |
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
17 | |
18 | namespace { |
19 | |
20 | py::tuple toPyTuple(const RRefForkData& rrefForkData) { |
21 | // add GIL as it is contructing a py::object |
22 | pybind11::gil_scoped_acquire ag; |
23 | return py::make_tuple( |
24 | rrefForkData.ownerId_, |
25 | rrefForkData.rrefId_.createdOn_, |
26 | rrefForkData.rrefId_.localId_, |
27 | rrefForkData.forkId_.createdOn_, |
28 | rrefForkData.forkId_.localId_, |
29 | rrefForkData.parent_, |
30 | rrefForkData.typeStr_); |
31 | } |
32 | |
33 | RRefForkData fromPyTuple(const py::tuple& pyTuple) { |
34 | // add GIL as it is accessing a py::object |
35 | pybind11::gil_scoped_acquire ag; |
36 | TORCH_INTERNAL_ASSERT( |
37 | pyTuple.size() == RFD_TUPLE_SIZE, |
38 | "Pickled RRefForkData must contain " , |
39 | RFD_TUPLE_SIZE, |
40 | " numbers." ); |
41 | worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>(); |
42 | // const reference will extend the lifetime of the temporary variable |
43 | const RRefId& rrefId = RRefId( |
44 | pyTuple[RREFID_ON_IDX].cast<worker_id_t>(), |
45 | pyTuple[RREFID_ID_IDX].cast<local_id_t>()); |
46 | const RRefId& forkId = RRefId( |
47 | pyTuple[FORKID_ON_IDX].cast<worker_id_t>(), |
48 | pyTuple[FORKID_ID_IDX].cast<local_id_t>()); |
49 | |
50 | worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>(); |
51 | const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>(); |
52 | |
53 | return RRefForkData(ownerId, rrefId, forkId, parent, typeStr); |
54 | } |
55 | |
56 | TypePtr tryInferTypeWithTypeHint( |
57 | const py::object& value, |
58 | const py::object& type_hint) { |
59 | // If the py::object to be contained by the RRef is a ScriptModule, we enforce |
60 | // users to specify its ModuleInterface type. |
61 | if (auto module = jit::as_module(value)) { |
62 | TORCH_CHECK( |
63 | !type_hint.is_none(), |
64 | "The RRef being created contains a ScriptModule, " |
65 | "must provide its ModuleInterface type hint. " ); |
66 | c10::QualifiedName type_qualified_name = c10::QualifiedName( |
67 | py::cast<std::string>(py::module::import("torch._jit_internal" ) |
68 | .attr("_qualified_name" )(type_hint))); |
69 | TypePtr type_hint_ptr = |
70 | jit::get_python_cu()->get_interface(type_qualified_name); |
71 | std::ostringstream subtype_check_msg; |
72 | TORCH_CHECK( |
73 | type_hint_ptr != nullptr && |
74 | module.value().type()->isSubtypeOfExt( |
75 | *type_hint_ptr, &subtype_check_msg), |
76 | module.value().type()->repr_str(), |
77 | " is not a subtype of the type hint: " , |
78 | type_qualified_name.qualifiedName(), |
79 | ", did you pass a valid interface type?\n" , |
80 | subtype_check_msg.str()); |
81 | return type_hint_ptr; |
82 | } else { |
83 | TORCH_CHECK( |
84 | type_hint.is_none(), |
85 | "type_hint should only be specified when the RRef being created contains a ScriptModule." ); |
86 | } |
87 | |
88 | // Check if value is an instance of a ScriptClass. If not, skip type inference |
89 | // because it will try to script the class that value is in instance of, and |
90 | // this should be avoided. |
91 | py::bool_ can_compile = py::module::import("torch._jit_internal" ) |
92 | .attr("can_compile_class" )(value.get_type()); |
93 | |
94 | if (py::cast<bool>(can_compile)) { |
95 | py::object existing_ty = py::module::import("torch.jit._state" ) |
96 | .attr("_get_script_class" )(value.get_type()); |
97 | |
98 | if (existing_ty.is_none()) { |
99 | return PyObjectType::get(); |
100 | } |
101 | } |
102 | |
103 | // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but |
104 | // excluding ScriptModule. |
105 | jit::InferredType type_inferred = jit::tryToInferType(value); |
106 | if (type_inferred.success()) { |
107 | // If we could infer the type from the pyobject, we create |
108 | // the RRef with the IValue of that type. |
109 | return type_inferred.type(); |
110 | } |
111 | |
112 | // Otherwise it's a pure pyobject, create the RRef |
113 | // that holds an IValue of an pyobject. |
114 | return PyObjectType::get(); |
115 | } |
116 | |
117 | } // namespace |
118 | |
119 | /////////////////////////// PyRRef ////////////////////////////////// |
120 | |
121 | PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) |
122 | : rref_(std::move(rref)), profilingFuture_(c10::nullopt) { |
123 | TORCH_CHECK(rref_, "PyRRef must not wrap nullptr" ); |
124 | C10_LOG_API_USAGE_ONCE("torch.distributed.rref" ); |
125 | } |
126 | |
127 | PyRRef::PyRRef(const py::object& value, const py::object& type_hint) |
128 | : PyRRef([&value, &type_hint]() mutable { |
129 | TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint); |
130 | auto rref = RRefContext::getInstance().createOwnerRRef(elem_type); |
131 | // jit::toIValue takes a py::handle as the first argument, and it calls |
132 | // py::handle.cast<py::object>() to incref of provided value. The |
133 | // returned ivalue will keep the reference alive. |
134 | // NB: the first argument const py::object& value must be kept alive |
135 | // until the following jit::toIValue returns (i.e., incref done). That's |
136 | // why this ctor can only be called while holding GIL. |
137 | IValue ivalue = jit::toIValue(value, elem_type); |
138 | rref->setValue(std::move(ivalue)); |
139 | return rref; |
140 | }()) {} |
141 | |
142 | PyRRef::~PyRRef() { |
143 | if (type_.has_value()) { |
144 | (*type_).dec_ref(); |
145 | // explicitly setting PyObject* to nullptr to prevent py::object's dtor to |
146 | // decref on the PyObject again. |
147 | // See Note [Destructing py::object] in python_ivalue.h |
148 | (*type_).ptr() = nullptr; |
149 | } |
150 | } |
151 | |
152 | c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const { |
153 | // Marking hasValue to false, as this Future is only used for signaling |
154 | // profiler to update profiling result and the profiler does not retrieve |
155 | // any value from it. |
156 | return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */); |
157 | } |
158 | |
159 | c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const { |
160 | TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!" ); |
161 | return *profilingFuture_; |
162 | } |
163 | |
164 | void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) { |
165 | profilingFuture_ = std::move(profilingFuture); |
166 | } |
167 | |
168 | bool PyRRef::isOwner() const { |
169 | return rref_->isOwner(); |
170 | } |
171 | |
172 | bool PyRRef::confirmedByOwner() const { |
173 | return rref_->confirmedByOwner(); |
174 | } |
175 | |
176 | WorkerInfo PyRRef::owner() const { |
177 | return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner()); |
178 | } |
179 | |
180 | std::string PyRRef::ownerName() const { |
181 | return rref_->ownerName(); |
182 | } |
183 | |
184 | py::object PyRRef::toHere(const float timeoutSeconds) const { |
185 | C10_LOG_API_USAGE_ONCE("torch.distributed.rref.to_here" ); |
186 | if (rref_->isOwner()) { |
187 | return localValue(); |
188 | } else { |
189 | // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds |
190 | // a python object |
191 | IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere( |
192 | timeoutSeconds); |
193 | |
194 | if (rref_->isPyObj()) { |
195 | // python_rpc_handler deserialization will acquires GIL. |
196 | auto rfr_values = value.toTupleRef().elements().vec(); |
197 | auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
198 | auto ret = pythonRpcHandler.deserialize( |
199 | SerializedPyObj::fromIValues(std::move(rfr_values))); |
200 | pythonRpcHandler.handleException(ret); |
201 | return ret; |
202 | } else { |
203 | // acquiring GIL as torch::jit::toPyObject creates new py::object |
204 | // without grabbing the GIL. |
205 | pybind11::gil_scoped_acquire ag; |
206 | return torch::jit::toPyObject(std::move(value)); |
207 | } |
208 | } |
209 | } |
210 | |
211 | py::object PyRRef::localValue() const { |
212 | TORCH_CHECK( |
213 | rref_->isOwner(), |
214 | "For " , |
215 | *rref_, |
216 | ", can't call localValue() on user " , |
217 | RRefContext::getInstance().agent()->getWorkerInfo(), |
218 | ". Call it on owner " , |
219 | owner()); |
220 | |
221 | py::object res; |
222 | auto value = |
223 | c10::static_intrusive_pointer_cast<const OwnerRRef>(rref_)->getValue(); |
224 | auto& rpcHandler = PythonRpcHandler::getInstance(); |
225 | { |
226 | // acquiring GIL as torch::jit::toPyObject creates new py::object without |
227 | // grabbing the GIL. |
228 | pybind11::gil_scoped_acquire ag; |
229 | res = torch::jit::toPyObject(std::move(value)); |
230 | rpcHandler.handleExceptionGILHeld(res); |
231 | } |
232 | return res; |
233 | } |
234 | |
235 | std::string PyRRef::str() const { |
236 | if (rref_->isOwner()) { |
237 | return c10::str("OwnerRRef(" , rref_->rrefId(), ")" ); |
238 | } else { |
239 | return c10::str( |
240 | "UserRRef(RRefId = " , |
241 | rref_->rrefId(), |
242 | ", ForkId = " , |
243 | c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(), |
244 | ")" ); |
245 | } |
246 | } |
247 | |
248 | py::object PyRRef::createRRefProxy( |
249 | const RRefProxyType& type, |
250 | float timeoutSeconds) const { |
251 | auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
252 | pybind11::gil_scoped_acquire ag; |
253 | auto& functions = pythonRpcHandler.getRRefProxyFunctions(); |
254 | auto& ctor = functions.rrefProxyCtor_; |
255 | switch (type) { |
256 | case RRefProxyType::RPC_SYNC: { |
257 | return ctor(*this, functions.rpcSync_, timeoutSeconds); |
258 | } |
259 | case RRefProxyType::RPC_ASYNC: { |
260 | return ctor(*this, functions.rpcAsync_, timeoutSeconds); |
261 | } |
262 | case RRefProxyType::REMOTE: { |
263 | return ctor(*this, functions.remote_, timeoutSeconds); |
264 | } |
265 | default: { |
266 | TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type " , type); |
267 | } |
268 | } |
269 | } |
270 | |
271 | py::object PyRRef::getRRefType(float timeout, bool blocking) { |
272 | // GIL is not released when calling this function. |
273 | if (!type_.has_value()) { |
274 | pybind11::gil_scoped_release release; |
275 | auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
276 | auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions(); |
277 | pybind11::gil_scoped_acquire acquire; |
278 | type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking) |
279 | : typeFuncs.onUser_(*this, timeout, blocking); |
280 | } |
281 | // Returns py::object that can be Python type or future. |
282 | return *type_; |
283 | } |
284 | |
285 | py::tuple PyRRef::pickle() const { |
286 | auto& ctx = RRefContext::getInstance(); |
287 | auto rrefForkData = ctx.prepareChildFork(rref_); |
288 | return toPyTuple(rrefForkData); |
289 | } |
290 | |
291 | PyRRef PyRRef::unpickle(const py::tuple& pyTuple) { |
292 | auto& ctx = RRefContext::getInstance(); |
293 | auto rrefForkData = fromPyTuple(pyTuple); |
294 | TypePtr rrefType = |
295 | PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_); |
296 | c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType); |
297 | ctx.notifyOwnerAndParentOfFork( |
298 | rrefForkData.forkId_, rrefForkData.parent_, rref); |
299 | return PyRRef(std::move(rref)); |
300 | } |
301 | |
302 | c10::IValue PyRRef::toIValue() const { |
303 | // cast to RRefInterface to hold it into IValue |
304 | auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_); |
305 | return IValue(rrefPtr); |
306 | } |
307 | |
308 | void PyRRef::backward(int64_t autogradContextId, bool retainGraph) { |
309 | backward(autogradContextId, retainGraph, rref_); |
310 | } |
311 | |
312 | void PyRRef::backwardOwnerRRef( |
313 | int64_t autogradContextId, |
314 | bool retainGraph, |
315 | IValue value) { |
316 | // If we have a PyObj, retrieve the underlying tensor. |
317 | if (value.isPyObject()) { |
318 | py::gil_scoped_acquire gil; |
319 | py::object obj = torch::jit::toPyObject(value); |
320 | try { |
321 | value = torch::jit::toIValue(obj, c10::TensorType::get()); |
322 | } catch (py::cast_error& e) { |
323 | TORCH_CHECK(false, "RRef should contain a tensor for .backward()" ); |
324 | } |
325 | } |
326 | |
327 | TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()" ); |
328 | auto root = value.toTensor(); |
329 | |
330 | if (autogradContextId == -1) { |
331 | torch::autograd::backward({root}); |
332 | } else { |
333 | torch::distributed::autograd::backward( |
334 | autogradContextId, {root}, retainGraph); |
335 | } |
336 | } |
337 | |
338 | void PyRRef::backward( |
339 | int64_t autogradContextId, |
340 | bool retainGraph, |
341 | const c10::intrusive_ptr<RRef>& rref) { |
342 | if (rref->isOwner()) { |
343 | backwardOwnerRRef( |
344 | autogradContextId, |
345 | retainGraph, |
346 | c10::static_intrusive_pointer_cast<const OwnerRRef>(rref)->getValue()); |
347 | } else { |
348 | TORCH_CHECK( |
349 | autogradContextId != -1, |
350 | "User RRefs require 'dist_autograd_ctx_id' to be specified" ); |
351 | |
352 | autograd::RRefBackwardReq rrefBackwardReq( |
353 | rref->rrefId(), autogradContextId, retainGraph); |
354 | |
355 | // Invoke distributed backward remotely. |
356 | auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); |
357 | rpcAgent |
358 | ->send( |
359 | rpcAgent->getWorkerInfo(rref->owner()), |
360 | std::move(rrefBackwardReq).toMessage()) |
361 | ->waitAndThrow(); |
362 | } |
363 | } |
364 | |
365 | } // namespace rpc |
366 | } // namespace distributed |
367 | } // namespace torch |
368 | |