1 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
2 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
3 | #include <torch/csrc/jit/python/pybind_utils.h> |
4 | #include <torch/csrc/utils/python_compat.h> |
5 | |
6 | namespace torch { |
7 | namespace distributed { |
8 | namespace rpc { |
9 | |
10 | namespace { |
11 | |
// Fully qualified name of the Python module that hosts the RPC helper
// functions looked up by PythonRpcHandler.
constexpr const char* kInternalModule = "torch.distributed.rpc.internal";
13 | |
// Acquires the GIL for the remainder of the enclosing scope. When the current
// RpcAgent has GIL profiling enabled, the time spent waiting for the GIL is
// measured (in microseconds) and handed to RpcAgent::addGilWaitTime(). The
// average GIL acquisition time will be recorded in RpcAgent's getMetrics().
//
// NB: the macro expands to several statements and declares the locals
// `shouldProfileGIL`, `startTime` and `ag`, so it can be used at most once per
// scope and must not be used as the unbraced body of an `if`/`for`.
#define PROFILE_GIL_SCOPED_ACQUIRE                                          \
  const auto shouldProfileGIL =                                             \
      RpcAgent::getCurrentRpcAgent()->isGILProfilingEnabled();              \
  const auto startTime = shouldProfileGIL                                   \
      ? std::chrono::high_resolution_clock::now()                           \
      : std::chrono::time_point<std::chrono::high_resolution_clock>();      \
  pybind11::gil_scoped_acquire ag;                                          \
  if (shouldProfileGIL) {                                                   \
    const auto dur = std::chrono::duration_cast<std::chrono::microseconds>( \
        std::chrono::high_resolution_clock::now() - startTime);             \
    RpcAgent::getCurrentRpcAgent()->addGilWaitTime(dur);                    \
  } // NOLINT
29 | |
30 | // PythonTypeResolver that inherits from Script::Resolver to |
31 | // support resolving types together with ScriptTypeParser. |
32 | struct PythonTypeResolver : public jit::Resolver { |
33 | std::shared_ptr<jit::SugaredValue> resolveValue( |
34 | const std::string& /* unused */, |
35 | torch::jit::GraphFunction& /* unused */, |
36 | const jit::SourceRange& /* unused */) override { |
37 | TORCH_INTERNAL_ASSERT( |
38 | false, "RPC Type resolver does not need to resolve value" ); |
39 | } |
40 | |
41 | TypePtr resolveType( |
42 | const std::string& name, |
43 | const jit::SourceRange& /* unused */) override { |
44 | if (name == "PyObject" ) { |
45 | return PyObjectType::get(); |
46 | } |
47 | return PythonRpcHandler::getInstance().jitCompilationUnit()->get_type(name); |
48 | } |
49 | }; |
50 | |
51 | py::object getFunction(const py::object& module, const char* name) { |
52 | py::object fn = module.attr(name); |
53 | TORCH_CHECK( |
54 | py::isinstance<py::function>(fn), |
55 | "attribute " , |
56 | name, |
57 | " is not a function" ); |
58 | return fn; |
59 | } |
60 | |
61 | void cleanupPyObj(py::object& obj) { |
62 | obj.dec_ref(); |
63 | // explicitly setting PyObject* to nullptr to prevent py::object's dtor to |
64 | // decref on the PyObject again. |
65 | // See Note [Destructing py::object] in python_ivalue.h |
66 | obj.ptr() = nullptr; |
67 | } |
68 | |
69 | } // namespace |
70 | |
71 | void PythonRpcHandler::init() { |
72 | std::lock_guard<std::mutex> guard(init_lock_); |
73 | if (!initialized_) { |
74 | PROFILE_GIL_SCOPED_ACQUIRE; |
75 | py::object rpcInternal = py::module::import(kInternalModule); |
76 | py::object rpcApi = py::module::import("torch.distributed.rpc.api" ); |
77 | py::object rrefProxy = |
78 | py::module::import("torch.distributed.rpc.rref_proxy" ); |
79 | |
80 | pyRunFunction_ = getFunction(rpcInternal, "_run_function" ); |
81 | pySerialize_ = getFunction(rpcInternal, "serialize" ); |
82 | pyDeserialize_ = getFunction(rpcInternal, "deserialize" ); |
83 | pyHandleException_ = getFunction(rpcInternal, "_handle_exception" ); |
84 | |
85 | rrefTypeFunctions_.onOwner_ = getFunction(rpcApi, "_rref_typeof_on_owner" ); |
86 | rrefTypeFunctions_.onUser_ = getFunction(rpcApi, "_rref_typeof_on_user" ); |
87 | |
88 | rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync" ); |
89 | rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async" ); |
90 | rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote" ); |
91 | rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy" ); |
92 | |
93 | jitCompilationUnit_ = torch::jit::get_python_cu(); |
94 | typeParser_ = std::make_shared<jit::ScriptTypeParser>( |
95 | std::make_shared<PythonTypeResolver>()); |
96 | initialized_ = true; |
97 | } |
98 | } |
99 | |
100 | PythonRpcHandler::PythonRpcHandler() : initialized_(false) {} |
101 | |
102 | void PythonRpcHandler::cleanup() { |
103 | std::lock_guard<std::mutex> guard(init_lock_); |
104 | PROFILE_GIL_SCOPED_ACQUIRE; |
105 | cleanupPyObj(pyRunFunction_); |
106 | cleanupPyObj(pySerialize_); |
107 | cleanupPyObj(pyDeserialize_); |
108 | cleanupPyObj(pyHandleException_); |
109 | |
110 | cleanupPyObj(rrefProxyFunctions_.rpcSync_); |
111 | cleanupPyObj(rrefProxyFunctions_.rpcAsync_); |
112 | cleanupPyObj(rrefProxyFunctions_.remote_); |
113 | cleanupPyObj(rrefProxyFunctions_.rrefProxyCtor_); |
114 | |
115 | jitCompilationUnit_ = nullptr; |
116 | typeParser_ = nullptr; |
117 | initialized_ = false; |
118 | } |
119 | |
120 | PythonRpcHandler& PythonRpcHandler::getInstance() { |
121 | // A thread could hold GIL when calling PythonRpcHandler::getInstance(), |
122 | // meantime another thread could have been doing static data |
123 | // initialization by calling `new PythonRpcHandler()`, inside of which GIL is |
124 | // also required. Static data initialization is thread-safe, so the thread |
125 | // holding the GIL will wait for the other thread to finish static data |
126 | // initializating before going forward. Because the initialization can't |
127 | // proceed without GIL, there is a deadlock. We ask the calling thread to |
128 | // release GIL to avoid this situation. |
129 | TORCH_INTERNAL_ASSERT(!PyGILState_Check()); |
130 | // Leaky singleton to avoid module destructor race. |
131 | static PythonRpcHandler* handler = new PythonRpcHandler(); |
132 | handler->init(); |
133 | return *handler; |
134 | } |
135 | |
136 | std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler:: |
137 | jitCompilationUnit() { |
138 | return jitCompilationUnit_; |
139 | } |
140 | |
141 | py::object PythonRpcHandler::runPythonUdf(const py::object& pythonUdf) { |
142 | PROFILE_GIL_SCOPED_ACQUIRE; |
143 | // Throw a descriptive error message if pyRunFunction_ is already cleaned up. |
144 | TORCH_INTERNAL_ASSERT( |
145 | !pyRunFunction_.is_none(), |
146 | "Cannot run python UDF since pyRunFunction_ is None. Check if python RPC " |
147 | "handler is already cleaned up." ); |
148 | return pyRunFunction_(pythonUdf); |
149 | } |
150 | |
151 | SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) { |
152 | PROFILE_GIL_SCOPED_ACQUIRE; |
153 | py::tuple t = pySerialize_(obj); |
154 | return SerializedPyObj( |
155 | t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>()); |
156 | } |
157 | |
158 | py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) { |
159 | PROFILE_GIL_SCOPED_ACQUIRE; |
160 | // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python |
161 | // function fails. Functions consuming the result needs to handle such error |
162 | // properly. |
163 | return pyDeserialize_( |
164 | py::bytes(serializedObj.payload_), serializedObj.tensors_); |
165 | } |
166 | |
167 | void PythonRpcHandler::handleException(const py::object& obj) { |
168 | PROFILE_GIL_SCOPED_ACQUIRE; |
169 | pyHandleException_(obj); |
170 | } |
171 | |
172 | void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) { |
173 | TORCH_CHECK(PyGILState_Check(), "GIL should be held" ); |
174 | pyHandleException_(obj); |
175 | } |
176 | |
177 | bool PythonRpcHandler::isRemoteException(const py::object& obj) { |
178 | PROFILE_GIL_SCOPED_ACQUIRE; |
179 | auto type = obj.get_type(); |
180 | auto moduleName = type.attr("__module__" ).cast<std::string>(); |
181 | auto qualName = type.attr("__qualname__" ).cast<std::string>(); |
182 | return moduleName == kInternalModule && qualName == "RemoteException" ; |
183 | } |
184 | |
185 | TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) { |
186 | return typeParser_->parseType(type_str); |
187 | } |
188 | |
189 | const PythonRpcHandler::RRefProxyFunctions& PythonRpcHandler:: |
190 | getRRefProxyFunctions() const { |
191 | return rrefProxyFunctions_; |
192 | } |
193 | |
194 | const PythonRpcHandler::RRefTypeFunctions& PythonRpcHandler:: |
195 | getRRefTypeFunctions() const { |
196 | return rrefTypeFunctions_; |
197 | } |
198 | |
199 | } // namespace rpc |
200 | } // namespace distributed |
201 | } // namespace torch |
202 | |