#pragma once

#include <memory>
#include <mutex>
#include <string>

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>
7 | |
8 | namespace torch { |
9 | namespace distributed { |
10 | namespace rpc { |
11 | |
12 | // Singleton class provides interface to execute python UDF remote call |
13 | // and deserialize the returned results by running python function |
14 | // in internal_rpc_utilities. |
15 | // The singleton object is constructed at first when RPC agent is |
16 | // constructed, where the python function in |
17 | // torch/distributed/internal_rpc_utils.py are imported only once. |
18 | class PYBIND11_EXPORT PythonRpcHandler { |
19 | public: |
20 | struct RRefProxyFunctions { |
21 | py::object rrefProxyCtor_; |
22 | py::object rpcSync_; |
23 | py::object rpcAsync_; |
24 | py::object remote_; |
25 | }; |
26 | |
27 | struct RRefTypeFunctions { |
28 | py::object onOwner_; |
29 | py::object onUser_; |
30 | }; |
31 | |
32 | static PythonRpcHandler& getInstance(); |
33 | |
34 | // Run a pickled Python UDF and return the result py::object |
35 | py::object runPythonUdf(const py::object& pythonUdf); |
36 | |
37 | // Serialized a py::object into a string |
38 | SerializedPyObj serialize(const py::object& obj); |
39 | |
40 | // Deserialize a string into a py::object |
41 | py::object deserialize(const SerializedPyObj& serializedObj); |
42 | |
43 | // Check if obj is RemoteException, then throw it |
44 | void handleException(const py::object& obj); |
45 | // Alternative if the caller is already holding the GIL. |
46 | void handleExceptionGILHeld(const py::object& obj); |
47 | // Check if obj is an RemoteException instance. |
48 | bool isRemoteException(const py::object& obj); |
49 | |
50 | // Explicitly clean up py::objects to avoid segment faults when |
51 | // py::objects with CPython are cleaned up later at program exit |
52 | // See similar issues reported https://github.com/pybind/pybind11/issues/1598 |
53 | // and https://github.com/pybind/pybind11/issues/1493 |
54 | // Our local tests also caught this segment faults if py::objects are cleaned |
55 | // up at program exit. The explanation is: CPython cleans up most critical |
56 | // utilities before cleaning up PythonRpcHandler singleton, so when |
57 | // PythonRpcHandler singleton cleans up py::objects and call dec_ref(), it |
58 | // will crash. |
59 | // The solution is to clean up py::objects earlier when Rpc agent join(). |
60 | // Be note that py::objects can not be cleaned up when Rpc agent is destroyed |
61 | // as well, as Rpc agent is global variable and it will have same issue as |
62 | // PythonRpcHandler. |
63 | void cleanup(); |
64 | |
65 | std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit(); |
66 | |
67 | // Parse the string to recover the jit_type, this is used for RRef python |
68 | // pickling/unpickling type recovery. The type string inference rule is as |
69 | // follows: |
70 | // 1. first try to parse if this is primitive types. |
71 | // i.e. TensorType, IntType, PyObjectType, etc. |
72 | // 2. if not primitive type, we query the python_cu to see if it is a |
73 | // class type or interface type registered in python |
74 | // We use a ScriptTypeParser instance with custom PythonTypeResolver |
75 | // to resolve types according to the above rules. |
76 | TypePtr parseTypeFromStr(const std::string& typeStr); |
77 | |
78 | // Return a set of Python functions for RRef helpers. |
79 | const RRefProxyFunctions& getRRefProxyFunctions() const; |
80 | |
81 | // Return a set of Python functions to retrieve the type of the object |
82 | // referenced by a given RRef. |
83 | const RRefTypeFunctions& getRRefTypeFunctions() const; |
84 | |
85 | PythonRpcHandler(const PythonRpcHandler&) = delete; |
86 | PythonRpcHandler& operator=(const PythonRpcHandler&) = delete; |
87 | PythonRpcHandler(PythonRpcHandler&&) = delete; |
88 | PythonRpcHandler& operator=(PythonRpcHandler&&) = delete; |
89 | |
90 | private: |
91 | void init(); |
92 | PythonRpcHandler(); |
93 | ~PythonRpcHandler() = default; |
94 | |
95 | // Ref to `torch.distributed.rpc.internal._run_function`. |
96 | py::object pyRunFunction_; |
97 | |
98 | // Ref to `torch.distributed.rpc.internal.serialize`. |
99 | py::object pySerialize_; |
100 | |
101 | // Ref to `torch.distributed.rpc.internal.deserialize`. |
102 | py::object pyDeserialize_; |
103 | |
104 | // Ref to 'torch.distributed.rpc.internal._handle_exception' |
105 | py::object pyHandleException_; |
106 | |
107 | // Python functions for RRef proxy |
108 | RRefProxyFunctions rrefProxyFunctions_; |
109 | |
110 | // Ref to 'torch.distributed.rpc.api._rref_typeof_on_' |
111 | RRefTypeFunctions rrefTypeFunctions_; |
112 | |
113 | // Shared ptr to python compilation unit in jit, it is constructed in python |
114 | // side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py) |
115 | // and imported in C++ (see get_python_cu() in |
116 | // csrc/jit/python/pybind_utils.h). We import the compilation unit here only |
117 | // once for less cost and thread safety. |
118 | std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_; |
119 | |
120 | // jit type parser to parse type_str back to TypePtr for RRef type |
121 | // recovery when pickling and unpickling RRef |
122 | std::shared_ptr<jit::ScriptTypeParser> typeParser_; |
123 | |
124 | // Indicates whether or not we have properly initialized the handler. |
125 | bool initialized_; |
126 | |
127 | // Lock to protect initialization. |
128 | std::mutex init_lock_; |
129 | }; |
130 | |
131 | } // namespace rpc |
132 | } // namespace distributed |
133 | } // namespace torch |
134 | |