#include <torch/csrc/distributed/rpc/python_rpc_handler.h>

#include <chrono>

#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/python_compat.h>
5
6namespace torch {
7namespace distributed {
8namespace rpc {
9
10namespace {
11
// Fully qualified name of the Python module that init() imports to fetch
// _run_function, serialize, deserialize and _handle_exception, and that
// isRemoteException() compares an exception type's __module__ against.
constexpr const char* kInternalModule = "torch.distributed.rpc.internal";
13
// A macro that grabs the GIL, profiling the acquisition time. The average GIL
// acquisition time will be recorded in RpcAgent's getMetrics().
//
// The macro deliberately expands to plain declarations in the *enclosing*
// scope (`startTime`, `shouldProfileGIL`, and the `ag` GIL guard) so that the
// GIL stays held until that scope ends. For that reason it cannot be wrapped
// in `do { ... } while (0)`, and it can be used at most once per scope.
//
// The wait is measured with std::chrono::steady_clock, which is monotonic.
// high_resolution_clock is not guaranteed to be steady (it may alias
// system_clock), so wall-clock adjustments could yield bogus or negative
// wait times.
#define PROFILE_GIL_SCOPED_ACQUIRE                                    \
  std::chrono::steady_clock::time_point startTime;                    \
  auto shouldProfileGIL =                                             \
      RpcAgent::getCurrentRpcAgent()->isGILProfilingEnabled();        \
  if (shouldProfileGIL) {                                             \
    startTime = std::chrono::steady_clock::now();                     \
  }                                                                   \
  pybind11::gil_scoped_acquire ag;                                    \
  if (shouldProfileGIL) {                                             \
    auto dur = std::chrono::duration_cast<std::chrono::microseconds>( \
        std::chrono::steady_clock::now() - startTime);                \
    RpcAgent::getCurrentRpcAgent()->addGilWaitTime(dur);              \
  } // NOLINT
29
30// PythonTypeResolver that inherits from Script::Resolver to
31// support resolving types together with ScriptTypeParser.
32struct PythonTypeResolver : public jit::Resolver {
33 std::shared_ptr<jit::SugaredValue> resolveValue(
34 const std::string& /* unused */,
35 torch::jit::GraphFunction& /* unused */,
36 const jit::SourceRange& /* unused */) override {
37 TORCH_INTERNAL_ASSERT(
38 false, "RPC Type resolver does not need to resolve value");
39 }
40
41 TypePtr resolveType(
42 const std::string& name,
43 const jit::SourceRange& /* unused */) override {
44 if (name == "PyObject") {
45 return PyObjectType::get();
46 }
47 return PythonRpcHandler::getInstance().jitCompilationUnit()->get_type(name);
48 }
49};
50
51py::object getFunction(const py::object& module, const char* name) {
52 py::object fn = module.attr(name);
53 TORCH_CHECK(
54 py::isinstance<py::function>(fn),
55 "attribute ",
56 name,
57 " is not a function");
58 return fn;
59}
60
61void cleanupPyObj(py::object& obj) {
62 obj.dec_ref();
63 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
64 // decref on the PyObject again.
65 // See Note [Destructing py::object] in python_ivalue.h
66 obj.ptr() = nullptr;
67}
68
69} // namespace
70
71void PythonRpcHandler::init() {
72 std::lock_guard<std::mutex> guard(init_lock_);
73 if (!initialized_) {
74 PROFILE_GIL_SCOPED_ACQUIRE;
75 py::object rpcInternal = py::module::import(kInternalModule);
76 py::object rpcApi = py::module::import("torch.distributed.rpc.api");
77 py::object rrefProxy =
78 py::module::import("torch.distributed.rpc.rref_proxy");
79
80 pyRunFunction_ = getFunction(rpcInternal, "_run_function");
81 pySerialize_ = getFunction(rpcInternal, "serialize");
82 pyDeserialize_ = getFunction(rpcInternal, "deserialize");
83 pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
84
85 rrefTypeFunctions_.onOwner_ = getFunction(rpcApi, "_rref_typeof_on_owner");
86 rrefTypeFunctions_.onUser_ = getFunction(rpcApi, "_rref_typeof_on_user");
87
88 rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
89 rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
90 rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
91 rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
92
93 jitCompilationUnit_ = torch::jit::get_python_cu();
94 typeParser_ = std::make_shared<jit::ScriptTypeParser>(
95 std::make_shared<PythonTypeResolver>());
96 initialized_ = true;
97 }
98}
99
100PythonRpcHandler::PythonRpcHandler() : initialized_(false) {}
101
102void PythonRpcHandler::cleanup() {
103 std::lock_guard<std::mutex> guard(init_lock_);
104 PROFILE_GIL_SCOPED_ACQUIRE;
105 cleanupPyObj(pyRunFunction_);
106 cleanupPyObj(pySerialize_);
107 cleanupPyObj(pyDeserialize_);
108 cleanupPyObj(pyHandleException_);
109
110 cleanupPyObj(rrefProxyFunctions_.rpcSync_);
111 cleanupPyObj(rrefProxyFunctions_.rpcAsync_);
112 cleanupPyObj(rrefProxyFunctions_.remote_);
113 cleanupPyObj(rrefProxyFunctions_.rrefProxyCtor_);
114
115 jitCompilationUnit_ = nullptr;
116 typeParser_ = nullptr;
117 initialized_ = false;
118}
119
120PythonRpcHandler& PythonRpcHandler::getInstance() {
121 // A thread could hold GIL when calling PythonRpcHandler::getInstance(),
122 // meantime another thread could have been doing static data
123 // initialization by calling `new PythonRpcHandler()`, inside of which GIL is
124 // also required. Static data initialization is thread-safe, so the thread
125 // holding the GIL will wait for the other thread to finish static data
126 // initializating before going forward. Because the initialization can't
127 // proceed without GIL, there is a deadlock. We ask the calling thread to
128 // release GIL to avoid this situation.
129 TORCH_INTERNAL_ASSERT(!PyGILState_Check());
130 // Leaky singleton to avoid module destructor race.
131 static PythonRpcHandler* handler = new PythonRpcHandler();
132 handler->init();
133 return *handler;
134}
135
136std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
137 jitCompilationUnit() {
138 return jitCompilationUnit_;
139}
140
141py::object PythonRpcHandler::runPythonUdf(const py::object& pythonUdf) {
142 PROFILE_GIL_SCOPED_ACQUIRE;
143 // Throw a descriptive error message if pyRunFunction_ is already cleaned up.
144 TORCH_INTERNAL_ASSERT(
145 !pyRunFunction_.is_none(),
146 "Cannot run python UDF since pyRunFunction_ is None. Check if python RPC "
147 "handler is already cleaned up.");
148 return pyRunFunction_(pythonUdf);
149}
150
151SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
152 PROFILE_GIL_SCOPED_ACQUIRE;
153 py::tuple t = pySerialize_(obj);
154 return SerializedPyObj(
155 t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
156}
157
158py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
159 PROFILE_GIL_SCOPED_ACQUIRE;
160 // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
161 // function fails. Functions consuming the result needs to handle such error
162 // properly.
163 return pyDeserialize_(
164 py::bytes(serializedObj.payload_), serializedObj.tensors_);
165}
166
167void PythonRpcHandler::handleException(const py::object& obj) {
168 PROFILE_GIL_SCOPED_ACQUIRE;
169 pyHandleException_(obj);
170}
171
172void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) {
173 TORCH_CHECK(PyGILState_Check(), "GIL should be held");
174 pyHandleException_(obj);
175}
176
177bool PythonRpcHandler::isRemoteException(const py::object& obj) {
178 PROFILE_GIL_SCOPED_ACQUIRE;
179 auto type = obj.get_type();
180 auto moduleName = type.attr("__module__").cast<std::string>();
181 auto qualName = type.attr("__qualname__").cast<std::string>();
182 return moduleName == kInternalModule && qualName == "RemoteException";
183}
184
185TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
186 return typeParser_->parseType(type_str);
187}
188
189const PythonRpcHandler::RRefProxyFunctions& PythonRpcHandler::
190 getRRefProxyFunctions() const {
191 return rrefProxyFunctions_;
192}
193
194const PythonRpcHandler::RRefTypeFunctions& PythonRpcHandler::
195 getRRefTypeFunctions() const {
196 return rrefTypeFunctions_;
197}
198
199} // namespace rpc
200} // namespace distributed
201} // namespace torch
202