1#pragma once
2
3#include <torch/csrc/distributed/autograd/context/context.h>
4#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
5#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
6#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
7
8namespace torch {
9namespace distributed {
10namespace autograd {
11
// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. This method creates a new 'send' autograd function
// and attaches the provided tensors as next_edges to the 'send' function. In
// addition to this, it also registers the send function in the provided
// autograd context. Finally, the RPC message is updated with appropriate
// autograd information for the recipient.
//
// autogradContext: distributed autograd context the 'send' function is
//     registered with.
// autogradMetadata: autograd metadata recorded into the outgoing RPC so the
//     recipient can link its matching 'recv' function back to this send.
// tensors: tensors attached as next_edges of the 'send' function; taken by
//     non-const reference — NOTE(review): presumably because autograd
//     bookkeeping is recorded on the tensors themselves; confirm in the
//     implementation.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);
22
// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. This method creates a new 'recv' autograd function
// and attaches the provided tensors as inputs to the 'recv' function. It
// creates a new autograd context if needed and registers the 'recv' function
// with this context.
//
// autogradMetadata: autograd metadata carried by the incoming RPC,
//     identifying the sender's matching 'send' function/context.
// tensors: received tensors to hook into the local autograd graph; taken by
//     non-const reference so the 'recv' function can be attached to them.
// fromWorkerId: id of the worker this RPC was received from.
// deviceMap: device mapping associated with the RPC — NOTE(review):
//     presumably used to map gradients/tensors between sender and receiver
//     devices during backward; confirm against the implementation.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap);
35
// This method is a wrapper utility used internally to wrap autograd info
// and attach autograd function for each type of rpc call if it has valid
// context and tensors require grads or forceGradRecording is true, in this
// case, return RpcWithAutograd message; otherwise return original rpc message.
// NB: forceGradRecording is useful when the request does not contain any tensor
// but the corresponding response does.
//
// dstId: id of the destination worker the message will be sent to.
// wrappedRpcMsg: the original RPC message; ownership is transferred and it
//     may be returned unchanged or wrapped in an RpcWithAutograd message.
// msgType: message type to use for the wrapping RpcWithAutograd message.
// forceGradRecording: wrap with autograd info even when the request carries
//     no tensors requiring grad (see NB above). Defaults to false.
// deviceMap: device mapping to embed in the wrapped message; defaults to
//     empty (no device mapping).
TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false,
    const rpc::DeviceMap& deviceMap = {});
48
// Sends 'wrappedRpcMsg' to 'dst' via 'agent' after attaching autograd
// information when appropriate (see getMessageWithAutograd above).
//
// agent: RPC agent used to perform the actual send.
// dst: destination worker.
// wrappedRpcMsg: the RPC message to send; ownership is transferred.
// forceGradRecording: force autograd wrapping even without grad-requiring
//     tensors in the request. Defaults to false.
// rpcTimeoutSeconds: timeout for the RPC; defaults to kUnsetRpcTimeout,
//     which — NOTE(review): presumably falls back to the agent's default
//     timeout; confirm in RpcAgent.
// forceDisableProfiling: when true, skip attaching profiling information to
//     the outgoing message. Defaults to false.
//
// Returns a future that completes with the response message.
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);
57
58} // namespace autograd
59} // namespace distributed
60} // namespace torch
61