1 | #include <torch/csrc/distributed/rpc/request_callback.h> |
---|---|
2 | |
3 | #include <torch/csrc/distributed/autograd/context/container.h> |
4 | #include <torch/csrc/distributed/autograd/utils.h> |
5 | |
6 | namespace torch { |
7 | namespace distributed { |
8 | namespace rpc { |
9 | |
10 | using namespace torch::distributed::autograd; |
11 | |
12 | c10::intrusive_ptr<JitFuture> RequestCallback::operator()( |
13 | Message& request, |
14 | std::vector<c10::Stream> streams) const { |
15 | // NB: cannot clear autograd context id here because the processMessage method |
16 | // might pause waiting for all RRefs in the arguments to be confirmed by their |
17 | // owners and resumne processing in a different thread. Hence, the |
18 | // thread_local context id needs to be set and cleared in the thread that |
19 | // indeed carries out the processing logic. |
20 | return processMessage(request, std::move(streams)); |
21 | } |
22 | |
23 | } // namespace rpc |
24 | } // namespace distributed |
25 | } // namespace torch |
26 |