1#pragma once
2
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>

#include <memory>
#include <mutex>
#include <string>
7
8namespace torch {
9namespace distributed {
10namespace rpc {
11
12// Singleton class provides interface to execute python UDF remote call
13// and deserialize the returned results by running python function
14// in internal_rpc_utilities.
15// The singleton object is constructed at first when RPC agent is
16// constructed, where the python function in
17// torch/distributed/internal_rpc_utils.py are imported only once.
18class PYBIND11_EXPORT PythonRpcHandler {
19 public:
20 struct RRefProxyFunctions {
21 py::object rrefProxyCtor_;
22 py::object rpcSync_;
23 py::object rpcAsync_;
24 py::object remote_;
25 };
26
27 struct RRefTypeFunctions {
28 py::object onOwner_;
29 py::object onUser_;
30 };
31
32 static PythonRpcHandler& getInstance();
33
34 // Run a pickled Python UDF and return the result py::object
35 py::object runPythonUdf(const py::object& pythonUdf);
36
37 // Serialized a py::object into a string
38 SerializedPyObj serialize(const py::object& obj);
39
40 // Deserialize a string into a py::object
41 py::object deserialize(const SerializedPyObj& serializedObj);
42
43 // Check if obj is RemoteException, then throw it
44 void handleException(const py::object& obj);
45 // Alternative if the caller is already holding the GIL.
46 void handleExceptionGILHeld(const py::object& obj);
47 // Check if obj is an RemoteException instance.
48 bool isRemoteException(const py::object& obj);
49
50 // Explicitly clean up py::objects to avoid segment faults when
51 // py::objects with CPython are cleaned up later at program exit
52 // See similar issues reported https://github.com/pybind/pybind11/issues/1598
53 // and https://github.com/pybind/pybind11/issues/1493
54 // Our local tests also caught this segment faults if py::objects are cleaned
55 // up at program exit. The explanation is: CPython cleans up most critical
56 // utilities before cleaning up PythonRpcHandler singleton, so when
57 // PythonRpcHandler singleton cleans up py::objects and call dec_ref(), it
58 // will crash.
59 // The solution is to clean up py::objects earlier when Rpc agent join().
60 // Be note that py::objects can not be cleaned up when Rpc agent is destroyed
61 // as well, as Rpc agent is global variable and it will have same issue as
62 // PythonRpcHandler.
63 void cleanup();
64
65 std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();
66
67 // Parse the string to recover the jit_type, this is used for RRef python
68 // pickling/unpickling type recovery. The type string inference rule is as
69 // follows:
70 // 1. first try to parse if this is primitive types.
71 // i.e. TensorType, IntType, PyObjectType, etc.
72 // 2. if not primitive type, we query the python_cu to see if it is a
73 // class type or interface type registered in python
74 // We use a ScriptTypeParser instance with custom PythonTypeResolver
75 // to resolve types according to the above rules.
76 TypePtr parseTypeFromStr(const std::string& typeStr);
77
78 // Return a set of Python functions for RRef helpers.
79 const RRefProxyFunctions& getRRefProxyFunctions() const;
80
81 // Return a set of Python functions to retrieve the type of the object
82 // referenced by a given RRef.
83 const RRefTypeFunctions& getRRefTypeFunctions() const;
84
85 PythonRpcHandler(const PythonRpcHandler&) = delete;
86 PythonRpcHandler& operator=(const PythonRpcHandler&) = delete;
87 PythonRpcHandler(PythonRpcHandler&&) = delete;
88 PythonRpcHandler& operator=(PythonRpcHandler&&) = delete;
89
90 private:
91 void init();
92 PythonRpcHandler();
93 ~PythonRpcHandler() = default;
94
95 // Ref to `torch.distributed.rpc.internal._run_function`.
96 py::object pyRunFunction_;
97
98 // Ref to `torch.distributed.rpc.internal.serialize`.
99 py::object pySerialize_;
100
101 // Ref to `torch.distributed.rpc.internal.deserialize`.
102 py::object pyDeserialize_;
103
104 // Ref to 'torch.distributed.rpc.internal._handle_exception'
105 py::object pyHandleException_;
106
107 // Python functions for RRef proxy
108 RRefProxyFunctions rrefProxyFunctions_;
109
110 // Ref to 'torch.distributed.rpc.api._rref_typeof_on_'
111 RRefTypeFunctions rrefTypeFunctions_;
112
113 // Shared ptr to python compilation unit in jit, it is constructed in python
114 // side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py)
115 // and imported in C++ (see get_python_cu() in
116 // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
117 // once for less cost and thread safety.
118 std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;
119
120 // jit type parser to parse type_str back to TypePtr for RRef type
121 // recovery when pickling and unpickling RRef
122 std::shared_ptr<jit::ScriptTypeParser> typeParser_;
123
124 // Indicates whether or not we have properly initialized the handler.
125 bool initialized_;
126
127 // Lock to protect initialization.
128 std::mutex init_lock_;
129};
130
131} // namespace rpc
132} // namespace distributed
133} // namespace torch
134