#include <torch/csrc/distributed/rpc/utils.h>

#include <fmt/format.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>

#include <c10/util/irange.h>

using namespace torch::autograd::profiler;

namespace torch {
namespace distributed {
namespace rpc {
namespace {
void processRemoteProfiledEvents(
    autograd::RpcWithProfilingResp& rpcWithProfilingResp) {
  // Check if the profiler is enabled
  auto enabled = profilerEnabled();
  TORCH_CHECK(
      enabled,
      "Profiler was expected to be enabled. This can happen in callback "
      "continuations that run in different threads, and the TLS of the "
      "profiler was not propagated.");
  std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents();
  const auto& profilingId = rpcWithProfilingResp.getProfilingId();
  auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
  auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId);
  remoteProfilerManager.eraseKey(profilingId);
  auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX;
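  // Each remote event is renamed below to
  // "<rpc profiling key><REMOTE_PROFILING_KEY_PREFIX><original event name>",
  // so that once the events are merged into the local profiler they remain
  // attributable to the RPC that produced them.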
  std::for_each(
      events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) {
        std::string name = keyPrefixStr + std::string(event.name());
        event.setName(at::StringView(name));
      });
  // Add event list to the thread local profiler.
  addEventList(std::move(events));
}

} // namespace

const std::string kRPCErrorPrefix = std::string("RPCErr");

RPCErrorType getRPCErrorType(const JitFuture& jitFuture) {
  TORCH_INTERNAL_ASSERT(
      jitFuture.hasError(),
      "JitFuture of Message passed to getRPCErrorType does not have an error.");

  // Attempt to parse the error string produced by makeRPCError(); if the
  // error is not in that format, return UNKNOWN_ERROR.
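  // For reference, makeRPCError() below formats errors as
  // "<kRPCErrorPrefix>:<int(errorType)>:<error message>", so the integer
  // between the two ':' delimiters following the prefix is the RPCErrorType
  // recovered here.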
  auto err = jitFuture.tryRetrieveErrorMessage();
  size_t pos = err.find(kRPCErrorPrefix);
  if (pos != std::string::npos) {
    // Parse the RPCErrorType.
    auto errStartIdx =
        pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
    auto errEndIdx = err.find(':', errStartIdx);
    if (errEndIdx == std::string::npos) {
      // Indicates error was not formatted correctly.
      return RPCErrorType::UNKNOWN_ERROR;
    }
    auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
    auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
    return errType;
  } else {
    return RPCErrorType::UNKNOWN_ERROR;
  }
}

std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType) {
  return fmt::format(
      "{}:{}:{}",
      torch::distributed::rpc::kRPCErrorPrefix,
      static_cast<int>(errorType),
      rpcErrorStr);
}

std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      return ScriptCall::fromMessage(request);
    }
    case MessageType::PYTHON_CALL: {
      return PythonCall::fromMessage(request);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return ScriptRemoteCall::fromMessage(request);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return PythonRemoteCall::fromMessage(request);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return ScriptRRefFetchCall::fromMessage(request);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return PythonRRefFetchCall::fromMessage(request);
    }
    case MessageType::RREF_USER_DELETE: {
      return RRefUserDelete::fromMessage(request);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return RRefChildAccept::fromMessage(request);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return RRefForkRequest::fromMessage(request);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return autograd::RpcWithAutograd::fromMessage(request);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return autograd::PropagateGradientsReq::fromMessage(request);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return autograd::CleanupAutogradContextReq::fromMessage(request);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return autograd::RpcWithProfilingReq::fromMessage(request);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return autograd::RRefBackwardReq::fromMessage(request);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", request.type(), " not supported.");
    }
  }
}

std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType) {
  switch (response.type()) {
    case MessageType::SCRIPT_RET: {
      return ScriptResp::fromMessage(response);
    }
    case MessageType::PYTHON_RET: {
      return PythonResp::fromMessage(response);
    }
    case MessageType::REMOTE_RET: {
      return RemoteRet::fromMessage(response);
    }
    case MessageType::SCRIPT_RREF_FETCH_RET: {
      return ScriptRRefFetchRet::fromMessage(response);
    }
    case MessageType::PYTHON_RREF_FETCH_RET: {
      return PythonRRefFetchRet::fromMessage(response);
    }
    case MessageType::RREF_ACK: {
      return RRefAck::fromMessage(response);
    }
    case MessageType::FORWARD_AUTOGRAD_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithAutograd::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);

      // Need to reverse the device map for the backward pass of distributed
      // autograd.
      DeviceMap reverseDeviceMap;
      for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
        reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
      }

      // Attach 'recv' autograd function.
      addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId(),
          reverseDeviceMap);

      wrappedMsgType = rpcWithAutograd.wrappedMessageType();

      return std::move(rpcWithAutograd).moveWrappedRpc();
    }
    case MessageType::BACKWARD_AUTOGRAD_RESP: {
      return autograd::PropagateGradientsResp::fromMessage(response);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
      return autograd::CleanupAutogradContextResp::fromMessage(response);
    }
    case MessageType::RUN_WITH_PROFILING_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithProfilingResp::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithProfilingResp =
          static_cast<autograd::RpcWithProfilingResp&>(rpc);
      // Process remotely profiled events.
      processRemoteProfiledEvents(rpcWithProfilingResp);

      wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();
      auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
      return wrappedRPC;
    }
    case MessageType::RREF_BACKWARD_RESP: {
      return autograd::RRefBackwardResp::fromMessage(response);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", response.type(), " not supported.");
    }
  }
}

IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      return ret.value();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false,
          "Response type ",
          messageType,
          " is not supported to be deserialized to IValue.");
    }
  }
}

IValue deserializeRespToIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return deserializeResptoIValueInternal(*response, msgType);
}

namespace {

// Helper for wireDeserialize() below.
//
// The format we use below looks like:
//   section_name_1 size_1\n
//   section_name_2 size_2\n
//   ..
//   \n
//   [sections in order]
//
// Sections themselves include:
// - "payload" - the payload bits
// - "meta" - metadata for the unpickler
// - "0" ... - tensor sections for the unpickler
//
// Note that per the header comments, the format is subject to change,
// and is best used for RPCs, rather than persistent disk storage.
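//
// For example, a message with a 3-byte payload and a single pickled tensor
// might be laid out as follows (the section sizes are illustrative only):
//   payload 3\n
//   meta 142\n
//   0 16\n
//   \n
//   <3 payload bytes><142 meta bytes><16 tensor bytes>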
std::unordered_map<std::string, std::pair<const char*, size_t>>
parseWireSections(const void* data, size_t data_size) {
  const char* ptr = static_cast<const char*>(data);
  const char* endp = ptr + data_size;

  std::vector<std::pair<std::string, size_t>> headerEnts;
  bool ok = false;
  while (ptr != endp) {
    if (*ptr == '\n') {
      ok = true; // The only "correct" exit point.
      ++ptr;
      break;
    }
    // Parse name
    const char* namePtr = ptr;
    while (ptr != endp && *ptr != ' ') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    std::string name(namePtr, ptr - namePtr);
    if (++ptr == endp) {
      break; // past the ' '
    }
    // Parse size
    const char* sizePtr = ptr;
    while (ptr != endp && *ptr != '\n') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    size_t sz = c10::stoll(std::string(sizePtr, ptr - sizePtr));
    headerEnts.emplace_back(name, sz);
    ++ptr; // past the '\n'
  }
  if (!ok) {
    TORCH_CHECK(false, "failed parse");
  }

  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
  for (const auto& headerEnt : headerEnts) {
    out[headerEnt.first] = {ptr, headerEnt.second};
    ptr += headerEnt.second;
  }
  if (ptr != endp) {
    TORCH_CHECK(false, "failed bounds");
  }
  return out;
}

static const char* kMeta = "meta";
static const char* kPayload = "payload";
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors) {
  // Sanity-check: If the majority of bits don't need to go over the wire,
  // force a clone(). Some Tensors are effectively small views, only using
  // ~1% of the underlying Storage.
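  // For example, a float tensor viewing 100 elements (400 useful bytes) of a
  // 1,000,000-element storage (~4 MB) passes both thresholds below and is
  // cloned, whereas a contiguous tensor (storage size == useful size) is sent
  // as-is.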
  auto worthRecopying = [](const at::Tensor& t) -> bool {
    if (!t.has_storage()) {
      return false; // avoid throwing below.
    }
    auto storageSize = t.storage().nbytes();
    auto usefulSize = t.element_size() * t.numel();
    constexpr size_t kMinMultiple = 2;
    constexpr size_t kMinRecopyBytes = 8 * 1024;
    return storageSize >= kMinRecopyBytes &&
        storageSize >= usefulSize * kMinMultiple;
  };
  c10::List<at::Tensor> pTensors;
  pTensors.reserve(tensors.size());
  for (const auto& t : tensors) {
    pTensors.push_back(worthRecopying(t) ? t.clone() : t);
  }
  return pTensors;
}

std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(
        tensor.device().is_cpu(),
        "ProcessGroup RPC backend only supports",
        " CPU tensors, please move your tensors to CPU before sending ",
        "them over RPC. Found tensor on device: ",
        tensor.device());
  }

  struct Ent {
    std::string name;
    const char* data;
    size_t size;
  };
  std::vector<Ent> entries;
  std::string metaEntry;
  std::vector<at::Tensor> tensorData;

  if (!payload.empty()) {
    entries.push_back({kPayload, payload.data(), payload.size()});
  }

  if (!tensors.empty()) {
    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
      metaEntry.append(static_cast<const char*>(buf), sz);
      return sz;
    });
    pickler.protocol();
    pickler.pushIValue(cloneSparseTensors(tensors));
    pickler.stop();
    tensorData = pickler.tensorData();
    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
    for (const auto i : c10::irange(tensorData.size())) {
      // Construct a WriteableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData lives in function scope and
      // getWriteableTensorData() just records the tensors, the data()
      // pointers stay valid for CPU tensors. Note that RPC serde doesn't
      // support CUDA tensors yet; if we add support for CUDA tensors, we need
      // to be careful, since getWriteableTensorData() copies a CUDA tensor to
      // CPU and that copy's data() may be destroyed once we go out of scope
      // of this loop.
      auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
      entries.push_back(
          {c10::to_string(i),
           writeableTensorData.data(),
           writeableTensorData.sizeInBytes()});
    }
  }

  std::string header;
  size_t tot = 0;
  for (const auto& e : entries) {
    tot += e.size;
    header.append(e.name)
        .append(" ")
        .append(c10::to_string(e.size))
        .append("\n");
  }
  header.push_back('\n');

  std::string out;
  out.reserve(header.size() + tot);
  out.append(header);
  for (const auto& e : entries) {
    out.append(e.data, e.size);
  }
  return out;
}

std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size) {
  auto sections = parseWireSections(data, data_size);

  std::vector<char> payload;
  auto payloadIt = sections.find(kPayload);
  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
    payload.assign(
        payloadIt->second.first,
        payloadIt->second.first + payloadIt->second.second);
  }

  std::vector<at::Tensor> tensors;
  auto metaIt = sections.find(kMeta);
  if (metaIt != sections.end()) {
    const auto& metaData = metaIt->second;
    size_t metaDataPos = 0;
    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
      if (metaDataPos >= metaData.second || n == 0) {
        return 0;
      }
      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
      memcpy(buf, metaData.first + metaDataPos, toCopy);
      metaDataPos += toCopy;
      return toCopy;
    };
    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
      auto it = sections.find(ename);
      if (it == sections.end()) {
        TORCH_CHECK(false, "Couldn't find entity " + ename);
      }
      const auto& idat = it->second;
      auto dptr = at::getCPUAllocator()->allocate(idat.second);
      if (idat.second != 0) {
        memcpy(dptr.get(), idat.first, idat.second);
      }
      return dptr;
    };

    // No need to pass a typeResolver here, as this only ever unpickles
    // strings and tensors.
    torch::jit::Unpickler unpickler(
        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
    auto ival = unpickler.parse_ivalue();
    for (auto&& t : ival.toTensorList()) {
      tensors.emplace_back(std::move(t));
    }
  }
  return {std::move(payload), std::move(tensors)};
}

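// writeWrappedPayload() appends additionalPayload to originalPayload and then
// appends an 8-byte big-endian length, so the resulting buffer is laid out as:
//   [original payload bytes][additional payload bytes]
//   [int64 size of the additional payload, big-endian]
// readWrappedPayload() below reverses this, and expects the wrapped bytes to
// be a pickled tuple.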
void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload) {
  originalPayload.insert(
      originalPayload.end(),
      additionalPayload.begin(),
      additionalPayload.end());

  // Add the size of the additional payload.
  int64_t indexToWrite = originalPayload.size();
  originalPayload.resize(originalPayload.size() + sizeof(int64_t));
  const int64_t additionalPayloadSize = additionalPayload.size();
  torch::utils::THP_encodeInt64Buffer(
      reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
      &additionalPayloadSize,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
}

std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message) {
  // Read the additional payload and remove it from the payload.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t additionalPayloadSize;
  TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t));
  size_t indexToRead = payload.size() - sizeof(int64_t);
  torch::utils::THP_decodeInt64Buffer(
      &additionalPayloadSize,
      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
  payload.resize(indexToRead);

  TORCH_INTERNAL_ASSERT(
      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
      payload.size() > additionalPayloadSize,
      "Wrong payload sizes: payload.size() is ",
      payload.size(),
      " but additional payload size is ",
      additionalPayloadSize);
  auto wrappedPayloadBegin =
      static_cast<const char*>(message.payload().data()) + payload.size() -
      additionalPayloadSize;
  std::vector<torch::Tensor> tensorTable;
  IValue tuple = jit::unpickle(
      wrappedPayloadBegin,
      additionalPayloadSize,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      tensorTable);
  std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
  payload.resize(payload.size() - additionalPayloadSize);
  return tupleElements;
}

void populateRemoteProfiledEvents(
    std::vector<LegacyEvent>& profiledEvents,
    const ProfilerConfig& profilingConfig,
    const std::vector<std::vector<LegacyEvent>>& eventLists) {
  // Gather all events into a vector.
  for (auto& l : eventLists) {
    for (auto& e : l) {
      profiledEvents.push_back(e);
    }
  }
  // Find the __start_profile event.
  bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA;
  const LegacyEvent* profilerStart = nullptr;

  for (auto& e : profiledEvents) {
    if (std::string(e.name()) == "__start_profile") {
      profilerStart = &e;
      break;
    }
  }
  // We should always find __start_profile.
  TORCH_CHECK(
      profilerStart != nullptr, "Expected to find __start_profile event.");

  if (cudaProfilingEnabled) {
    // Deserialized events don't have the corresponding CUDA events, making it
    // impossible to use cudaEventElapsedTime on the receiving end. To avoid
    // this, find all push/pop pairs of CUDA events and set the corresponding
    // CUDA time to zero for the push event and to the elapsed time for the
    // pop event, to be used later for the elapsed CUDA time computation.
    std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*>
        startEvents;
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PushRange) {
          startEvents[e.handle()] = &e;
        }
      }
    }
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PopRange) {
          auto it = startEvents.find(e.handle());
          if (it != startEvents.end()) {
            e.setCudaUs(it->second->cudaElapsedUs(e));
          } else {
            TORCH_WARN("Found a pop event without a corresponding push event");
            e.setCudaUs(0);
          }
        } else {
          e.setCudaUs(0);
        }
      }
    }
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch