#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::WorkerInfo;

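// Attaches a 'send' autograd function to the given tensors as part of the
// forward pass of an RPC. Only tensors that require grad contribute edges to
// the SendRpcBackward node; the node is then recorded in the provided
// distributed autograd context under the given autograd message id.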
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

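// Attaches 'recv' autograd functions to tensors received over RPC as part of
// the forward pass on the callee. Looks up (or creates) the distributed
// autograd context for the sender's context id, sets RecvRpcBackward as the
// grad_fn of every tensor that requires grad, records the function in the
// context, and returns the context.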
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

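// Wraps an outgoing RPC message with profiling metadata so that the remote
// end processes the request with its profiler enabled. Generates a globally
// unique profiling id, saves a mapping from that id to the current
// thread-local profiling key, and unsets the key before wrapping the message.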
c10::intrusive_ptr<Message> getMessageWithProfiling(
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
    MessageType msgType,
    torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();

  auto key = remoteProfilerManager.getCurrentProfilingKey();
  // Generate a globally unique id.
  auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
  // Save a mapping of ID -> RPC profiling key and unset the current TLS key.
  remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
  remoteProfilerManager.unsetCurrentKey();
  auto wrappedProfilingMsg = RpcWithProfilingReq(
      msgType,
      std::move(wrappedRpcMessage),
      std::move(profilerConfig),
      globallyUniqueProfilingId);

  return std::move(wrappedProfilingMsg).toMessage();
}

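// Wraps an outgoing RPC message with autograd metadata if needed. When there
// is no valid distributed autograd context and no tensor requires grad (and
// grad recording is not forced), the original message is returned unchanged.
// Otherwise the message is wrapped in an RpcWithAutograd, a 'send' autograd
// function is recorded for tensors requiring grad, and the destination worker
// id is recorded in the current context.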
c10::intrusive_ptr<Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const rpc::DeviceMap& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grad, send the
  // original rpc message. Otherwise, attach grad info and grad functions and
  // send an rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return wrappedRpcMsg;
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the destination worker id.
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

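// Sends 'wrappedRpcMsg' to 'dst' via 'agent', wrapping it with autograd
// metadata (see getMessageWithAutograd above) and, when the legacy profiler
// is active, with profiling metadata so the callee also profiles the request.
//
// A minimal usage sketch, assuming a current RpcAgent, a peer named
// "worker1", and the default values declared for the trailing parameters
// (message construction is elided):
//
//   auto& agent = *rpc::RpcAgent::getCurrentRpcAgent();
//   auto jitFuture = sendMessageWithAutograd(
//       agent, agent.getWorkerInfo("worker1"), std::move(message));
//   jitFuture->waitAndThrow(); // Blocks until the remote response arrives.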
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  // If the profiler is enabled, wrap this message with profiling metadata
  // that will tell the remote end to process this request with the profiler
  // enabled.
  if (!forceDisableProfiling) {
    switch (torch::profiler::impl::profilerType()) {
      case torch::profiler::impl::ActiveProfilerType::LEGACY: {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ,
            std::move(profilerConfig));
        return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      }
      case torch::profiler::impl::ActiveProfilerType::KINETO:
        TORCH_WARN_ONCE(
            "Profiling a distributed call with the Kineto profiler will "
            "profile the caller, but not the worker.");
        break;
      default:
        break;
    }
  }

  return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}

} // namespace autograd
} // namespace distributed
} // namespace torch