1 | #pragma once |
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | #include <torch/csrc/autograd/profiler.h> |
5 | #include <torch/csrc/distributed/autograd/utils.h> |
6 | #include <torch/csrc/distributed/rpc/rref_context.h> |
7 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
8 | |
9 | namespace torch { |
10 | namespace distributed { |
11 | namespace rpc { |
12 | |
13 | // This function sends an rpc call to run torchscript function, currently the |
14 | // torchscript function could only be a user defined python function with |
15 | // "@torch.jit.script" annotation. The torchscript function could not be |
16 | // a class constructor, class method, instance method or a script module. |
17 | // dst: destination worker name |
18 | // qualifiedName: torchscript function qualified name string like |
19 | // "moduleName::torchscriptFunctionName", e.g, |
20 | // "dist_autograd_test::my_py_add" |
21 | // stack: a bag of IValue args passed to torchscriptFunctionName |
22 | // It returns c10::intrusive_ptr<ivalue::Future> |
23 | c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript( |
24 | const std::string& dstWorkerName, |
25 | const c10::QualifiedName& qualifiedName, |
26 | const c10::FunctionSchema& functionSchema, |
27 | std::vector<c10::IValue>& stack, |
28 | const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, |
29 | const bool isAsyncExecution = false); |
30 | |
31 | c10::intrusive_ptr<RRef> TORCH_API remoteTorchscript( |
32 | const std::string& dstWorkerName, |
33 | const c10::QualifiedName& qualifiedName, |
34 | const c10::FunctionSchema& functionSchema, |
35 | std::vector<c10::IValue>& stack, |
36 | const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, |
37 | const bool isAsyncExecution = false); |
38 | |
39 | } // namespace rpc |
40 | } // namespace distributed |
41 | } // namespace torch |
42 | |