1#include <gtest/gtest.h>
2
3#include <torch/csrc/distributed/autograd/context/container.h>
4#include <torch/csrc/distributed/autograd/context/context.h>
5#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
6#include <torch/csrc/distributed/autograd/utils.h>
7#include <torch/csrc/distributed/c10d/TCPStore.hpp>
8#include <torch/csrc/distributed/rpc/rref_context.h>
9#include <torch/csrc/distributed/rpc/script_call.h>
10#include <torch/csrc/distributed/rpc/script_remote_call.h>
11#include <torch/csrc/distributed/rpc/script_resp.h>
12#include <torch/csrc/distributed/rpc/utils.h>
13#include <torch/csrc/jit/runtime/operator.h>
14
15namespace torch {
16namespace distributed {
17namespace rpc {
18
19using torch::distributed::autograd::DistAutogradContainer;
20using torch::distributed::autograd::DistAutogradContext;
21
22DistAutogradContainer* getDistAutogradContainer();
23
// Common gtest fixture for RPC end-to-end tests.
//
// Concrete subclasses supply the transport by implementing buildRpcAgent()
// (which must populate `rpcAgent`). This base class owns the rest of the
// lifecycle: it creates a TCPStore rendezvous server, installs/starts the
// agent, and provides a small training loop that exercises forward RPCs,
// remote RRef creation, and a distributed autograd backward pass.
class TestE2EBase : public ::testing::Test {
 protected:
  void SetUp() override {
    // Setup distributed autograd.
    autogradContainer = getDistAutogradContainer();

    // Setup server store. This process acts as the store server
    // (isServer = true) and waits for all `numWorkers` workers to join.
    // NOTE(review): port 0 presumably lets the store pick a free port —
    // confirm against TCPStoreOptions semantics.
    c10d::TCPStoreOptions opts{
        /* port */ 0,
        /* isServer */ true,
        numWorkers,
        /* waitWorkers */ true,
        /* timeout */ std::chrono::seconds(10)};

    store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts);

    // Subclass hook: must construct `rpcAgent`.
    buildRpcAgent();

    rpcAgentPostProcessing();
  }

  // Finishes agent setup after buildRpcAgent(): registers the agent as the
  // process-wide current agent, installs a type resolver, and starts it.
  void rpcAgentPostProcessing() {
    RpcAgent::setCurrentRpcAgent(rpcAgent);
    // Minimal type resolver for this test: any qualified name containing
    // "Dict" resolves to Dict[str, str] (used for the device map); every
    // other name resolves to a plain Tensor type. The lambda uses only its
    // `qn` argument, so the [&] capture carries no enclosing state even
    // though the resolver outlives this stack frame.
    std::shared_ptr<TypeResolver> typeResolver =
        std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
          // For Dict that is used for device map.
          auto pos = qn.name().find("Dict");
          if (pos != std::string::npos) {
            return c10::StrongTypePtr(
                nullptr,
                c10::DictType::create(
                    c10::StringType::get(), c10::StringType::get()));
          }
          return c10::StrongTypePtr(
              nullptr, c10::TensorType::create(at::Tensor()));
        });
    rpcAgent->setTypeResolver(typeResolver);
    rpcAgent->start();
  }

  void TearDown() override {
    // Join before shutdown, then clear the global current-agent pointer so
    // later tests don't see a dangling agent.
    rpcAgent->join();
    rpcAgent->shutdown();
    RpcAgent::setCurrentRpcAgent(nullptr);
  }

  // Creates an OwnerRRef locally, then issues a remote call of `op` on
  // (t1, t2) whose result will populate that RRef. Returns the owner RRef;
  // callers typically block on getValue() to retrieve the tensor result.
  c10::intrusive_ptr<OwnerRRef> createRemoteRRef(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    auto& ctx = RRefContext::getInstance();
    auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1));
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);

    // rrefId is reused as the forkId here; the stack of {t1, t2, 1} matches
    // the aten::add(Tensor, Tensor, alpha) schema used by the tests.
    ScriptRemoteCall scriptRemoteCall(
        op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptRemoteCall).toMessage(),
        false);

    ownerRRef->registerOwnerCreationFuture(jitFuture);

    // Builtin operators do not return py::object, and hence do not require
    // the GIL for destructing the potentially deleted OwnerRRef.
    // Capture only the id (not the rref) to avoid extending its lifetime.
    jitFuture->addCallback(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) {
          callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId);
        });
    return ownerRRef;
  }

  // Synchronously runs `op` (aten::add in these tests) on the remote worker
  // with alpha = 1 and returns the resulting tensor. Throws if the RPC or
  // the remote execution failed (via waitAndThrow).
  at::Tensor remoteAdd(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1});

    // Send the RPC and return result.
    auto response = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptCall).toMessage());
    response->waitAndThrow();

    // The reply is wrapped in an autograd response; unwrap it to the
    // underlying ScriptResp and extract the tensor.
    MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP;
    auto wrappedResponse = deserializeResponse(
        std::move(*response->value().toCustomClass<Message>()), messageType);
    return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor();
  }

  // Subclass hook: construct the transport-specific agent into `rpcAgent`.
  virtual void buildRpcAgent() = 0;

  // RAII guard: opens a fresh distributed autograd context on construction
  // and releases it on destruction, scoping one training iteration.
  class AutogradContextGuard {
   public:
    explicit AutogradContextGuard()
        : context(DistAutogradContainer::getInstance().newContext()) {}

    ~AutogradContextGuard() {
      DistAutogradContainer::getInstance().releaseContext(context->contextId());
    }

   private:
    std::shared_ptr<DistAutogradContext> context;
  };

  // Runs `numIters` iterations: a chain of remote adds (forward), one remote
  // RRef creation, then a distributed backward from the summed result.
  void runTrainingLoop() {
    auto options = at::TensorOptions().requires_grad(true);
    auto t1 = torch::ones({3, 3}, options);
    auto t2 = torch::ones({3, 3}, options);

    // Resolve the aten::add.Tensor overload once, outside the loop.
    c10::OperatorName full_name("aten::add", "Tensor");
    auto matchedOp = torch::jit::findOperatorFor(full_name);
    ASSERT_TRUE(matchedOp);

    for (size_t i = 0; i < numIters; i++) {
      // Create the autograd context guard.
      AutogradContextGuard guard;

      // Multiple RPCs within one autograd context for the forward pass.
      auto result = remoteAdd(t1, t2, matchedOp);
      for (size_t j = 0; j < 5; j++) {
        result = remoteAdd(t1, result, matchedOp);
      }

      auto rref = createRemoteRRef(t1, result, matchedOp);
      result = rref->getValue().toTensor();

      // Run backward pass now. Scalar root is sum(result); the graph is not
      // retained since each iteration builds a fresh context.
      autograd::DistEngine::getInstance().execute(
          DistAutogradContainer::currentContextId(),
          {torch::sum(result)},
          /* retainGraph */ false);
    }
  }

  DistAutogradContainer* autogradContainer; // not owned; process-wide singleton
  std::shared_ptr<RpcAgent> rpcAgent;       // set by subclass buildRpcAgent()
  static const size_t numIters;             // training iterations (defined in .cpp)
  static const size_t numWorkers;           // workers the store waits for
  c10::intrusive_ptr<c10d::Store> store;    // TCPStore rendezvous server
  static const char* serverAddress;         // host for the TCPStore server
};
169
170} // namespace rpc
171} // namespace distributed
172} // namespace torch
173