1 | #include <ATen/ThreadLocalState.h> |
2 | #include <c10/util/ThreadLocalDebugInfo.h> |
3 | #include <torch/csrc/autograd/functions/utils.h> |
4 | #include <torch/csrc/autograd/profiler.h> |
5 | #include <torch/csrc/distributed/autograd/context/container.h> |
6 | #include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h> |
7 | #include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h> |
8 | #include <torch/csrc/distributed/autograd/utils.h> |
9 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> |
10 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
11 | #include <torch/csrc/distributed/rpc/types.h> |
12 | |
13 | namespace torch { |
14 | namespace distributed { |
15 | namespace autograd { |
16 | |
17 | using torch::distributed::autograd::AutogradMetadata; |
18 | using torch::distributed::autograd::RpcWithAutograd; |
19 | using torch::distributed::rpc::JitFuture; |
20 | using torch::distributed::rpc::Message; |
21 | using torch::distributed::rpc::MessageType; |
22 | using torch::distributed::rpc::RpcAgent; |
23 | using torch::distributed::rpc::WorkerInfo; |
24 | |
25 | void addSendRpcBackward( |
26 | const ContextPtr& autogradContext, |
27 | const AutogradMetadata& autogradMetadata, |
28 | std::vector<torch::Tensor>& tensors) { |
29 | // Attach autograd information only for tensors requiring grad. |
30 | std::vector<torch::Tensor> tensors_with_grad; |
31 | std::copy_if( |
32 | tensors.begin(), |
33 | tensors.end(), |
34 | std::back_inserter(tensors_with_grad), |
35 | [](const torch::Tensor& t) { return t.requires_grad(); }); |
36 | |
37 | // Attach the appropriate autograd edges. |
38 | auto grad_fn = std::make_shared<SendRpcBackward>(); |
39 | grad_fn->set_next_edges( |
40 | torch::autograd::collect_next_edges(tensors_with_grad)); |
41 | |
42 | // Add the appropriate input metadata for the grad_fn. |
43 | for (const auto& tensor : tensors_with_grad) { |
44 | grad_fn->add_input_metadata(tensor); |
45 | } |
46 | |
47 | // Record the send autograd function in our current context. |
48 | autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId); |
49 | } |
50 | |
51 | ContextPtr addRecvRpcBackward( |
52 | const AutogradMetadata& autogradMetadata, |
53 | std::vector<torch::Tensor>& tensors, |
54 | rpc::worker_id_t fromWorkerId, |
55 | const rpc::DeviceMap& deviceMap) { |
56 | // Initialize autograd context if necessary. |
57 | auto& autogradContainer = DistAutogradContainer::getInstance(); |
58 | auto autogradContext = |
59 | autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId); |
60 | |
61 | if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) { |
62 | // Attach the tensors as inputs to the autograd function. |
63 | auto grad_fn = std::make_shared<RecvRpcBackward>( |
64 | autogradMetadata, autogradContext, fromWorkerId, deviceMap); |
65 | for (auto& tensor : tensors) { |
66 | if (tensor.requires_grad()) { |
67 | torch::autograd::set_history(tensor, grad_fn); |
68 | } |
69 | } |
70 | |
71 | // Now update the autograd context with the necessary information. |
72 | autogradContext->addRecvFunction( |
73 | grad_fn, autogradMetadata.autogradMessageId); |
74 | } |
75 | |
76 | return autogradContext; |
77 | } |
78 | |
79 | c10::intrusive_ptr<Message> getMessageWithProfiling( |
80 | c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage, |
81 | MessageType msgType, |
82 | torch::autograd::profiler::ProfilerConfig&& profilerConfig) { |
83 | auto& remoteProfilerManager = |
84 | torch::distributed::rpc::RemoteProfilerManager::getInstance(); |
85 | |
86 | auto key = remoteProfilerManager.getCurrentProfilingKey(); |
87 | // generate a globally unique Id |
88 | auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId(); |
89 | // Save a mapping of ID -> RPC profiling key and unset the current TLS key. |
90 | remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key); |
91 | remoteProfilerManager.unsetCurrentKey(); |
92 | auto wrappedProfilingMsg = RpcWithProfilingReq( |
93 | msgType, |
94 | std::move(wrappedRpcMessage), |
95 | std::move(profilerConfig), |
96 | globallyUniqueProfilingId); |
97 | |
98 | return std::move(wrappedProfilingMsg).toMessage(); |
99 | } |
100 | |
101 | c10::intrusive_ptr<Message> getMessageWithAutograd( |
102 | const rpc::worker_id_t dstId, |
103 | c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg, |
104 | MessageType msgType, |
105 | bool forceGradRecording, |
106 | const rpc::DeviceMap& deviceMap) { |
107 | auto& autogradContainer = DistAutogradContainer::getInstance(); |
108 | |
109 | // If there is no valid context and no tensor requires grads, send original |
110 | // rpc message. otherwise, attach grad info and grad functions and send |
111 | // rpcWithAutograd message. |
112 | auto tensorsRequireGrad = |
113 | torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors()); |
114 | if (!autogradContainer.hasValidContext() || |
115 | (!forceGradRecording && !tensorsRequireGrad)) { |
116 | return wrappedRpcMsg; |
117 | } |
118 | |
119 | // Retrieve the appropriate context to modify. |
120 | auto autogradContext = autogradContainer.currentContext(); |
121 | |
122 | // Wrap the original rpc with autograd information. |
123 | AutogradMetadata autogradMetadata( |
124 | autogradContext->contextId(), autogradContainer.newAutogradMessageId()); |
125 | auto rpcWithAutograd = std::make_unique<RpcWithAutograd>( |
126 | RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_, |
127 | msgType, |
128 | autogradMetadata, |
129 | std::move(wrappedRpcMsg), |
130 | deviceMap); |
131 | |
132 | if (tensorsRequireGrad) { |
133 | // Record autograd information for 'send'. |
134 | addSendRpcBackward( |
135 | autogradContext, autogradMetadata, rpcWithAutograd->tensors()); |
136 | } |
137 | // Record the workerID |
138 | autogradContext->addKnownWorkerId(dstId); |
139 | |
140 | return std::move(*rpcWithAutograd).toMessage(); |
141 | } |
142 | |
143 | c10::intrusive_ptr<JitFuture> sendMessageWithAutograd( |
144 | RpcAgent& agent, |
145 | const WorkerInfo& dst, |
146 | c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg, |
147 | bool forceGradRecording, |
148 | const float rpcTimeoutSeconds, |
149 | bool forceDisableProfiling) { |
150 | auto msg = getMessageWithAutograd( |
151 | dst.id_, |
152 | std::move(wrappedRpcMsg), |
153 | MessageType::FORWARD_AUTOGRAD_REQ, |
154 | forceGradRecording, |
155 | agent.getDeviceMap(dst)); |
156 | |
157 | // If profiler is enabled, wrap this message with profiling metadata that will |
158 | // tell the remote end to process this request with the profiler enabled. |
159 | if (!forceDisableProfiling) { |
160 | switch (torch::profiler::impl::profilerType()) { |
161 | case torch::profiler::impl::ActiveProfilerType::LEGACY: { |
162 | auto profilerConfig = torch::autograd::profiler::getProfilerConfig(); |
163 | auto msgWithProfiling = getMessageWithProfiling( |
164 | std::move(msg), |
165 | rpc::MessageType::RUN_WITH_PROFILING_REQ, |
166 | std::move(profilerConfig)); |
167 | return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds); |
168 | } |
169 | case torch::profiler::impl::ActiveProfilerType::KINETO: |
170 | TORCH_WARN_ONCE( |
171 | "Profiling a distributed call with the Kineto profiler will profile " |
172 | "the caller, but not the worker." ); |
173 | break; |
174 | default: |
175 | break; |
176 | } |
177 | } |
178 | |
179 | return agent.send(dst, std::move(msg), rpcTimeoutSeconds); |
180 | ; |
181 | } |
182 | |
183 | } // namespace autograd |
184 | } // namespace distributed |
185 | } // namespace torch |
186 | |