1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/autograd/context/context.h> |
4 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> |
5 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h> |
6 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h> |
7 | |
8 | namespace torch { |
9 | namespace distributed { |
10 | namespace autograd { |
11 | |
12 | // Attaches a 'send' autograd function to the autograd graph when RPC is used. |
13 | // A new 'send' function is created, the provided `tensors` are attached to it |
14 | // as next_edges, and the function is registered in `autogradContext`. |
15 | // Finally, the RPC message is updated with appropriate autograd information |
16 | // for the recipient. NOTE(review): no message is a parameter here; presumably |
17 | // `autogradMetadata` identifies that information - confirm against the .cpp. |
18 | TORCH_API void addSendRpcBackward( |
19 | const ContextPtr& autogradContext, |
20 | const AutogradMetadata& autogradMetadata, |
21 | std::vector<torch::Tensor>& tensors); |
22 | |
23 | // Attaches a 'recv' autograd function to the autograd graph when RPC is used. |
24 | // A new 'recv' function is created and the provided `tensors` are attached to |
25 | // it as inputs. A new autograd context is created if needed, and the 'recv' |
26 | // function is registered with this context. NOTE(review): the roles of |
27 | // `fromWorkerId` and `deviceMap` are not visible here - confirm in the .cpp. |
28 | // |
29 | // Returns a pointer to the autograd context created. |
30 | TORCH_API ContextPtr addRecvRpcBackward( |
31 | const AutogradMetadata& autogradMetadata, |
32 | std::vector<torch::Tensor>& tensors, |
33 | rpc::worker_id_t fromWorkerId, |
34 | const rpc::DeviceMap& deviceMap); |
35 | |
36 | // Internal wrapper utility that wraps autograd info around an rpc message and |
37 | // attaches the autograd function for each type of rpc call. When it has a |
38 | // valid context and tensors require grads or forceGradRecording is true, an |
39 | // RpcWithAutograd message is returned; otherwise the original rpc message is |
40 | // returned. NB: forceGradRecording (default false) is useful when the request |
41 | // does not contain any tensor but the corresponding response does. |
42 | TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd( |
43 | const rpc::worker_id_t dstId, |
44 | c10::intrusive_ptr<rpc::Message> wrappedRpcMsg, |
45 | rpc::MessageType msgType, |
46 | bool forceGradRecording = false, |
47 | const rpc::DeviceMap& deviceMap = {}); |
48 | |
49 | // Send message after autograd checking. Returns a Future; both bool flags default to false and rpcTimeoutSeconds defaults to rpc::kUnsetRpcTimeout. |
50 | TORCH_API c10::intrusive_ptr<c10::ivalue::Future> sendMessageWithAutograd( |
51 | rpc::RpcAgent& agent, |
52 | const rpc::WorkerInfo& dst, |
53 | c10::intrusive_ptr<rpc::Message> wrappedRpcMsg, |
54 | bool forceGradRecording = false, |
55 | const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, |
56 | bool forceDisableProfiling = false); |
57 | |
58 | } // namespace autograd |
59 | } // namespace distributed |
60 | } // namespace torch |
61 | |