1#pragma once
2
3#include <torch/csrc/distributed/rpc/py_rref.h>
4#include <torch/csrc/distributed/rpc/rpc_agent.h>
5#include <torch/csrc/jit/python/pybind_utils.h>
6#include <torch/csrc/utils/pybind.h>
7
8namespace torch {
9namespace distributed {
10namespace rpc {
11
12// Converts an internal ivalue::Future of Message into a user-facing
13// ivalue::Future of py::object type by creating a new ivalue::Future and call
14// its markCompleted as a callback in the given ivalue::Future.
15// If hasValue is true, the Message will be converted into a py::object and then
16// wrap it with an IValue. If hasValue is false, this ivalue::Future is only
17// used for signaling and launching callbacks. In this case, the message will be
18// discarded and then set the ivalue::Future using an empty IValue or the given
19// FutureError if there is an error.
20c10::intrusive_ptr<JitFuture> toPyJitFuture(
21 const c10::intrusive_ptr<JitFuture>& messageJitFuture,
22 bool hasValue = true);
23
24c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
25 const WorkerInfo& dst,
26 const std::string& opName,
27 const py::args& args,
28 const py::kwargs& kwargs,
29 const float rpcTimeoutSeconds);
30
31c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
32 const WorkerInfo& dst,
33 std::string& pickledPythonUDF,
34 std::vector<torch::Tensor>& tensors,
35 const float rpcTimeoutSeconds,
36 const bool isAsyncExecution);
37
38c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
39 const std::string& dstWorkerName,
40 const std::string& qualifiedNameStr,
41 const py::tuple& argsTuple,
42 const py::dict& kwargsDict,
43 const float rpcTimeoutSeconds,
44 const bool isAsyncExecution);
45
46PyRRef pyRemoteBuiltin(
47 const WorkerInfo& dst,
48 const std::string& opName,
49 const float rpcTimeoutSeconds,
50 const py::args& args,
51 const py::kwargs& kwargs);
52
53PyRRef pyRemotePythonUdf(
54 const WorkerInfo& dst,
55 std::string& pickledPythonUDF,
56 std::vector<torch::Tensor>& tensors,
57 const float rpcTimeoutSeconds,
58 const bool isAsyncExecution);
59
60PyRRef pyRemoteTorchscript(
61 const std::string& dstWorkerName,
62 const std::string& qualifiedNameStr,
63 const float rpcTimeoutSeconds,
64 const bool isAsyncExecution,
65 const py::args& args,
66 const py::kwargs& kwargs);
67
68} // namespace rpc
69} // namespace distributed
70} // namespace torch
71