#include <torch/csrc/distributed/rpc/utils.h>

#include <fmt/format.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>

#include <c10/util/irange.h>

using namespace torch::autograd::profiler;

namespace torch {
namespace distributed {
namespace rpc {
namespace {
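// Merges events profiled on a remote node into the local thread-local
// profiler. Each event name is prefixed with the profiling key registered for
// this RPC so that remote events are distinguishable in the local trace.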
void processRemoteProfiledEvents(
    autograd::RpcWithProfilingResp& rpcWithProfilingResp) {
  // Check if the profiler is enabled
  auto enabled = profilerEnabled();
  TORCH_CHECK(
      enabled,
      "Profiler was expected to be enabled. This can happen in callback "
      "continuations that run in different threads, and the TLS of the "
      "profiler was not propagated.");
  std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents();
  const auto& profilingId = rpcWithProfilingResp.getProfilingId();
  auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
  auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId);
  remoteProfilerManager.eraseKey(profilingId);
  auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX;
  std::for_each(
      events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) {
        std::string name = keyPrefixStr + std::string(event.name());
        event.setName(at::StringView(name));
      });
  // Add event list to the thread local profiler.
  addEventList(std::move(events));
}

} // namespace

const std::string kRPCErrorPrefix = std::string("RPCErr");

RPCErrorType getRPCErrorType(const JitFuture& jitFuture) {
  TORCH_INTERNAL_ASSERT(
      jitFuture.hasError(),
      "JitFuture of Message passed to getRPCErrorType does not have an error.");

  // Attempt to parse for error string given by makeRPCError, otherwise return
  // unknown error.
  // Note that this function expects errors formatted with makeRPCError().
  auto err = jitFuture.tryRetrieveErrorMessage();
  size_t pos = err.find(kRPCErrorPrefix);
  if (pos != std::string::npos) {
    // Parse the RPCErrorType.
    auto errStartIdx =
        pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
    auto errEndIdx = err.find(':', errStartIdx);
    if (errEndIdx == std::string::npos) {
      // Indicates error was not formatted correctly.
      return RPCErrorType::UNKNOWN_ERROR;
    }
    auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
    auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
    return errType;
  } else {
    return RPCErrorType::UNKNOWN_ERROR;
  }
}

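// Formats an error string so that getRPCErrorType() can later recover the
// error type. The resulting string looks like
// "RPCErr:<error type as int>:<rpcErrorStr>".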
std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType) {
  return fmt::format(
      "{}:{}:{}",
      torch::distributed::rpc::kRPCErrorPrefix,
      static_cast<int>(errorType),
      rpcErrorStr);
}

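// Converts an incoming request Message into the corresponding RpcCommandBase
// subclass, dispatching on the message type.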
std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      return ScriptCall::fromMessage(request);
    }
    case MessageType::PYTHON_CALL: {
      return PythonCall::fromMessage(request);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return ScriptRemoteCall::fromMessage(request);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return PythonRemoteCall::fromMessage(request);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return ScriptRRefFetchCall::fromMessage(request);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return PythonRRefFetchCall::fromMessage(request);
    }
    case MessageType::RREF_USER_DELETE: {
      return RRefUserDelete::fromMessage(request);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return RRefChildAccept::fromMessage(request);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return RRefForkRequest::fromMessage(request);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return autograd::RpcWithAutograd::fromMessage(request);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return autograd::PropagateGradientsReq::fromMessage(request);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return autograd::CleanupAutogradContextReq::fromMessage(request);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return autograd::RpcWithProfilingReq::fromMessage(request);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return autograd::RRefBackwardReq::fromMessage(request);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", request.type(), " not supported.");
    }
  }
}

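// Converts a response Message into the corresponding RpcCommandBase subclass.
// Autograd and profiling wrappers are unwrapped here: their side effects
// (attaching 'recv' autograd functions, merging remote profiler events) are
// applied, the inner RPC is returned, and wrappedMsgType is set to the inner
// message type.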
std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType) {
  switch (response.type()) {
    case MessageType::SCRIPT_RET: {
      return ScriptResp::fromMessage(response);
    }
    case MessageType::PYTHON_RET: {
      return PythonResp::fromMessage(response);
    }
    case MessageType::REMOTE_RET: {
      return RemoteRet::fromMessage(response);
    }
    case MessageType::SCRIPT_RREF_FETCH_RET: {
      return ScriptRRefFetchRet::fromMessage(response);
    }
    case MessageType::PYTHON_RREF_FETCH_RET: {
      return PythonRRefFetchRet::fromMessage(response);
    }
    case MessageType::RREF_ACK: {
      return RRefAck::fromMessage(response);
    }
    case MessageType::FORWARD_AUTOGRAD_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithAutograd::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);

      // Need to reverse the device map for the backward pass of distributed
      // autograd.
      DeviceMap reverseDeviceMap;
      for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
        reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
      }

      // Attach 'recv' autograd function.
      addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId(),
          reverseDeviceMap);

      wrappedMsgType = rpcWithAutograd.wrappedMessageType();

      return std::move(rpcWithAutograd).moveWrappedRpc();
    }
    case MessageType::BACKWARD_AUTOGRAD_RESP: {
      return autograd::PropagateGradientsResp::fromMessage(response);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
      return autograd::CleanupAutogradContextResp::fromMessage(response);
    }
    case MessageType::RUN_WITH_PROFILING_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithProfilingResp::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithProfilingResp =
          static_cast<autograd::RpcWithProfilingResp&>(rpc);
      // Process remotely profiled events.
      processRemoteProfiledEvents(rpcWithProfilingResp);

      wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();
      auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
      return wrappedRPC;
    }
    case MessageType::RREF_BACKWARD_RESP: {
      return autograd::RRefBackwardResp::fromMessage(response);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", response.type(), " not supported.");
    }
  }
}

IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      return ret.value();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false,
          "Response type ",
          messageType,
          " is not supported to be deserialized to IValue.");
    }
  }
}

IValue deserializeRespToIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return deserializeResptoIValueInternal(*response, msgType);
}

namespace {

// Helper for wireDeserialize() below.
//
// The format we use below looks like:
//    section_name_1 size_1\n
//    section_name_2 size_2\n
//    ..
//    \n
//    [sections in order]
//
// Sections themselves include:
//    - "payload" - the payload bits
//    - "meta"    - metadata for the unpickler
//    - "0" ...   - tensor sections for the unpickler
//
// Note that per the header comments, the format is subject to change,
// and is best used for rpcs, rather than persistent disk storage.
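//
// For example (section sizes here are purely illustrative), a message with a
// 4-byte payload, pickled metadata, and one tensor section would look like:
//    payload 4\n
//    meta 123\n
//    0 4096\n
//    \n
//    <4 payload bytes><123 meta bytes><4096 tensor bytes>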
std::unordered_map<std::string, std::pair<const char*, size_t>>
parseWireSections(const void* data, size_t data_size) {
  const char* ptr = static_cast<const char*>(data);
  const char* endp = ptr + data_size;

  std::vector<std::pair<std::string, size_t>> headerEnts;
  bool ok = false;
  while (ptr != endp) {
    if (*ptr == '\n') {
      ok = true; // The only "correct" exit point.
      ++ptr;
      break;
    }
    // Parse name
    const char* namePtr = ptr;
    while (ptr != endp && *ptr != ' ') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    std::string name(namePtr, ptr - namePtr);
    if (++ptr == endp) {
      break; // past the ' '
    }
    // Parse size
    const char* sizePtr = ptr;
    while (ptr != endp && *ptr != '\n') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    size_t sz = c10::stoll(std::string(sizePtr, ptr - sizePtr));
    headerEnts.emplace_back(name, sz);
    ++ptr; // past the '\n'
  }
  if (!ok) {
    TORCH_CHECK(false, "failed parse");
  }

  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
  for (const auto& headerEnt : headerEnts) {
    out[headerEnt.first] = {ptr, headerEnt.second};
    ptr += headerEnt.second;
  }
  if (ptr != endp) {
    TORCH_CHECK(false, "failed bounds");
  }
  return out;
}

static const char* kMeta = "meta";
static const char* kPayload = "payload";
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors) {
  // Sanity-check: If the majority of bits don't need to go over the wire,
  // force a clone(). Some Tensors are effectively small views, only using
  // ~1% of the underlying Storage.
  auto worthRecopying = [](const at::Tensor& t) -> bool {
    if (!t.has_storage()) {
      return false; // avoid throwing below.
    }
    auto storageSize = t.storage().nbytes();
    auto usefulSize = t.element_size() * t.numel();
    constexpr size_t kMinMultiple = 2;
    constexpr size_t kMinRecopyBytes = 8 * 1024;
    return storageSize >= kMinRecopyBytes &&
        storageSize >= usefulSize * kMinMultiple;
  };
  c10::List<at::Tensor> pTensors;
  pTensors.reserve(tensors.size());
  for (const auto& t : tensors) {
    pTensors.push_back(worthRecopying(t) ? t.clone() : t);
  }
  return pTensors;
}

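// Serializes the given payload and CPU tensors into a single string using the
// wire format parsed by parseWireSections(): a header listing each section
// name and size, a blank line, and then the raw section bytes in order.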
std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(
        tensor.device().is_cpu(),
        "ProcessGroup RPC backend only supports",
        " CPU tensors, please move your tensors to CPU before sending ",
        "them over RPC. Found tensor on device: ",
        tensor.device());
  }

  struct Ent {
    std::string name;
    const char* data;
    size_t size;
  };
  std::vector<Ent> entries;
  std::string metaEntry;
  std::vector<at::Tensor> tensorData;

  if (!payload.empty()) {
    entries.push_back({kPayload, payload.data(), payload.size()});
  }

  if (!tensors.empty()) {
    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
      metaEntry.append(static_cast<const char*>(buf), sz);
      return sz;
    });
    pickler.protocol();
    pickler.pushIValue(cloneSparseTensors(tensors));
    pickler.stop();
    tensorData = pickler.tensorData();
    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
    for (const auto i : c10::irange(tensorData.size())) {
      // Construct WritableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData is in function scope and
      // getWriteableTensorData just records the tensors, the data() pointers
      // stay valid for CPU tensors. Note that RPC serde doesn't support CUDA
      // tensors yet; if we were to support CUDA tensors, we would need to be
      // careful since getWriteableTensorData converts CUDA tensors to CPU and
      // data() might get destructed as we go out of scope of this loop.
      auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
      entries.push_back(
          {c10::to_string(i),
           writeableTensorData.data(),
           writeableTensorData.sizeInBytes()});
    }
  }

  std::string header;
  size_t tot = 0;
  for (const auto& e : entries) {
    tot += e.size;
    header.append(e.name)
        .append(" ")
        .append(c10::to_string(e.size))
        .append("\n");
  }
  header.push_back('\n');

  std::string out;
  out.reserve(header.size() + tot);
  out.append(header);
  for (const auto& e : entries) {
    out.append(e.data, e.size);
  }
  return out;
}

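// Inverse of wireSerialize(): splits the buffer into its named sections and
// reconstructs the payload bytes and the tensor list (via the unpickler).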
std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size) {
  auto sections = parseWireSections(data, data_size);

  std::vector<char> payload;
  auto payloadIt = sections.find(kPayload);
  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
    payload.assign(
        payloadIt->second.first,
        payloadIt->second.first + payloadIt->second.second);
  }

  std::vector<at::Tensor> tensors;
  auto metaIt = sections.find(kMeta);
  if (metaIt != sections.end()) {
    const auto& metaData = metaIt->second;
    size_t metaDataPos = 0;
    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
      if (metaDataPos >= metaData.second || n == 0) {
        return 0;
      }
      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
      memcpy(buf, metaData.first + metaDataPos, toCopy);
      metaDataPos += toCopy;
      return toCopy;
    };
    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
      auto it = sections.find(ename);
      if (it == sections.end()) {
        TORCH_CHECK(false, "Couldn't find entity " + ename);
      }
      const auto& idat = it->second;
      auto dptr = at::getCPUAllocator()->allocate(idat.second);
      if (idat.second != 0) {
        memcpy(dptr.get(), idat.first, idat.second);
      }
      return dptr;
    };

    // No need to pass typeResolver here, as it always processes strings and
    // tensors only.
    torch::jit::Unpickler unpickler(
        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
    auto ival = unpickler.parse_ivalue();
    for (auto&& t : ival.toTensorList()) {
      tensors.emplace_back(std::move(t));
    }
  }
  return {std::move(payload), std::move(tensors)};
}

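// Appends additionalPayload to originalPayload, followed by the size of
// additionalPayload encoded as a big-endian int64. readWrappedPayload() uses
// that trailing size to split the two apart again.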
void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload) {
  originalPayload.insert(
      originalPayload.end(),
      additionalPayload.begin(),
      additionalPayload.end());

  // Add size of the additional payload
  int64_t indexToWrite = originalPayload.size();
  originalPayload.resize(originalPayload.size() + sizeof(int64_t));
  const int64_t additionalPayloadSize = additionalPayload.size();
  torch::utils::THP_encodeInt64Buffer(
      reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
      &additionalPayloadSize,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
}

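// Inverse of writeWrappedPayload(): reads the trailing big-endian int64 size,
// unpickles the wrapped payload into a tuple of IValues, and shrinks payload
// back down to the original (unwrapped) contents.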
std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message) {
  // Read the additional payload and remove it from the payload.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t additionalPayloadSize;
  TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t));
  size_t indexToRead = payload.size() - sizeof(int64_t);
  torch::utils::THP_decodeInt64Buffer(
      &additionalPayloadSize,
      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
  payload.resize(indexToRead);

  TORCH_INTERNAL_ASSERT(
      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
      payload.size() > additionalPayloadSize,
      "Wrong payload sizes: payload.size() is ",
      payload.size(),
      " but additional payload size is ",
      additionalPayloadSize);
  auto wrappedPayloadBegin =
      static_cast<const char*>(message.payload().data()) + payload.size() -
      additionalPayloadSize;
  std::vector<torch::Tensor> tensorTable;
  IValue tuple = jit::unpickle(
      wrappedPayloadBegin,
      additionalPayloadSize,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      tensorTable);
  std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
  payload.resize(payload.size() - additionalPayloadSize);
  return tupleElements;
}

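// Gathers all events from eventLists into profiledEvents and checks that the
// __start_profile marker event is present. When CUDA profiling is enabled,
// also precomputes elapsed CUDA times (pop events get the elapsed time, push
// events are zeroed) since the receiving end cannot call cudaEventElapsedTime
// on deserialized events.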
void populateRemoteProfiledEvents(
    std::vector<LegacyEvent>& profiledEvents,
    const ProfilerConfig& profilingConfig,
    const std::vector<std::vector<LegacyEvent>>& eventLists) {
  // Gather all events into a vector
  for (auto& l : eventLists) {
    for (auto& e : l) {
      profiledEvents.push_back(e);
    }
  }
  // find __start_profile event
  bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA;
  const LegacyEvent* profilerStart = nullptr;

  for (auto& e : profiledEvents) {
    if (std::string(e.name()) == "__start_profile") {
      profilerStart = &e;
      break;
    }
  }
  // We should always find __start_profile.
  TORCH_CHECK(
      profilerStart != nullptr, "Expected to find __start_profile event.");

  if (cudaProfilingEnabled) {
    // Deserialized events don't have the corresponding CUDA events, making it
    // impossible to use cudaEventElapsedTime on the receiving end. To avoid
    // this, find all push/pop pairs of CUDA events and set the corresponding
    // CUDA time to zero for the push event and to the elapsed time for the
    // pop event, to be used later for the elapsed CUDA time computation.
    std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*>
        startEvents;
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PushRange) {
          startEvents[e.handle()] = &e;
        }
      }
    }
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PopRange) {
          auto it = startEvents.find(e.handle());
          if (it != startEvents.end()) {
            e.setCudaUs(it->second->cudaElapsedUs(e));
          } else {
            TORCH_WARN("Found a pop event without a corresponding push event");
            e.setCudaUs(0);
          }
        } else {
          e.setCudaUs(0);
        }
      }
    }
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch
585 | |