1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/distributed/autograd/context/container.h> |
4 | #include <torch/csrc/distributed/autograd/context/context.h> |
5 | #include <torch/csrc/distributed/autograd/engine/dist_engine.h> |
6 | #include <torch/csrc/distributed/autograd/utils.h> |
7 | #include <torch/csrc/distributed/c10d/TCPStore.hpp> |
8 | #include <torch/csrc/distributed/rpc/rref_context.h> |
9 | #include <torch/csrc/distributed/rpc/script_call.h> |
10 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
11 | #include <torch/csrc/distributed/rpc/script_resp.h> |
12 | #include <torch/csrc/distributed/rpc/utils.h> |
13 | #include <torch/csrc/jit/runtime/operator.h> |
14 | |
15 | namespace torch { |
16 | namespace distributed { |
17 | namespace rpc { |
18 | |
19 | using torch::distributed::autograd::DistAutogradContainer; |
20 | using torch::distributed::autograd::DistAutogradContext; |
21 | |
22 | DistAutogradContainer* getDistAutogradContainer(); |
23 | |
24 | class TestE2EBase : public ::testing::Test { |
25 | protected: |
26 | void SetUp() override { |
27 | // Setup distributed autograd. |
28 | autogradContainer = getDistAutogradContainer(); |
29 | |
30 | // Setup server store. |
31 | c10d::TCPStoreOptions opts{ |
32 | /* port */ 0, |
33 | /* isServer */ true, |
34 | numWorkers, |
35 | /* waitWorkers */ true, |
36 | /* timeout */ std::chrono::seconds(10)}; |
37 | |
38 | store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts); |
39 | |
40 | buildRpcAgent(); |
41 | |
42 | rpcAgentPostProcessing(); |
43 | } |
44 | |
45 | void rpcAgentPostProcessing() { |
46 | RpcAgent::setCurrentRpcAgent(rpcAgent); |
47 | std::shared_ptr<TypeResolver> typeResolver = |
48 | std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) { |
49 | // For Dict that is used for device map. |
50 | auto pos = qn.name().find("Dict" ); |
51 | if (pos != std::string::npos) { |
52 | return c10::StrongTypePtr( |
53 | nullptr, |
54 | c10::DictType::create( |
55 | c10::StringType::get(), c10::StringType::get())); |
56 | } |
57 | return c10::StrongTypePtr( |
58 | nullptr, c10::TensorType::create(at::Tensor())); |
59 | }); |
60 | rpcAgent->setTypeResolver(typeResolver); |
61 | rpcAgent->start(); |
62 | } |
63 | |
64 | void TearDown() override { |
65 | rpcAgent->join(); |
66 | rpcAgent->shutdown(); |
67 | RpcAgent::setCurrentRpcAgent(nullptr); |
68 | } |
69 | |
70 | c10::intrusive_ptr<OwnerRRef> createRemoteRRef( |
71 | at::Tensor t1, |
72 | at::Tensor t2, |
73 | std::shared_ptr<torch::jit::Operator> op) { |
74 | auto& ctx = RRefContext::getInstance(); |
75 | auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1)); |
76 | // prevent this owner RRef being deleted due to other forks |
77 | ctx.addSelfAsFork(ownerRRef); |
78 | |
79 | ScriptRemoteCall scriptRemoteCall( |
80 | op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); |
81 | auto jitFuture = autograd::sendMessageWithAutograd( |
82 | *rpcAgent, |
83 | rpcAgent->getWorkerInfo("worker" ), |
84 | std::move(scriptRemoteCall).toMessage(), |
85 | false); |
86 | |
87 | ownerRRef->registerOwnerCreationFuture(jitFuture); |
88 | |
89 | // Builtin operators does not return py::object, and hence does not require |
90 | // GIL for destructing the potentially deleted OwerRRef. |
91 | jitFuture->addCallback( |
92 | [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) { |
93 | callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId); |
94 | }); |
95 | return ownerRRef; |
96 | } |
97 | |
98 | at::Tensor remoteAdd( |
99 | at::Tensor t1, |
100 | at::Tensor t2, |
101 | std::shared_ptr<torch::jit::Operator> op) { |
102 | ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1}); |
103 | |
104 | // Send the RPC and return result. |
105 | auto response = autograd::sendMessageWithAutograd( |
106 | *rpcAgent, |
107 | rpcAgent->getWorkerInfo("worker" ), |
108 | std::move(scriptCall).toMessage()); |
109 | response->waitAndThrow(); |
110 | |
111 | MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; |
112 | auto wrappedResponse = deserializeResponse( |
113 | std::move(*response->value().toCustomClass<Message>()), messageType); |
114 | return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor(); |
115 | } |
116 | |
117 | virtual void buildRpcAgent() = 0; |
118 | |
119 | class AutogradContextGuard { |
120 | public: |
121 | explicit AutogradContextGuard() |
122 | : context(DistAutogradContainer::getInstance().newContext()) {} |
123 | |
124 | ~AutogradContextGuard() { |
125 | DistAutogradContainer::getInstance().releaseContext(context->contextId()); |
126 | } |
127 | |
128 | private: |
129 | std::shared_ptr<DistAutogradContext> context; |
130 | }; |
131 | |
132 | void runTrainingLoop() { |
133 | auto options = at::TensorOptions().requires_grad(true); |
134 | auto t1 = torch::ones({3, 3}, options); |
135 | auto t2 = torch::ones({3, 3}, options); |
136 | |
137 | c10::OperatorName full_name("aten::add" , "Tensor" ); |
138 | auto matchedOp = torch::jit::findOperatorFor(full_name); |
139 | ASSERT_TRUE(matchedOp); |
140 | |
141 | for (size_t i = 0; i < numIters; i++) { |
142 | // Create the autograd context guard. |
143 | AutogradContextGuard guard; |
144 | |
145 | // Multiple RPCs within one autograd context for the forward pass. |
146 | auto result = remoteAdd(t1, t2, matchedOp); |
147 | for (size_t j = 0; j < 5; j++) { |
148 | result = remoteAdd(t1, result, matchedOp); |
149 | } |
150 | |
151 | auto rref = createRemoteRRef(t1, result, matchedOp); |
152 | result = rref->getValue().toTensor(); |
153 | |
154 | // Run backward pass now. |
155 | autograd::DistEngine::getInstance().execute( |
156 | DistAutogradContainer::currentContextId(), |
157 | {torch::sum(result)}, |
158 | /* retainGraph */ false); |
159 | } |
160 | } |
161 | |
162 | DistAutogradContainer* autogradContainer; |
163 | std::shared_ptr<RpcAgent> rpcAgent; |
164 | static const size_t numIters; |
165 | static const size_t numWorkers; |
166 | c10::intrusive_ptr<c10d::Store> store; |
167 | static const char* serverAddress; |
168 | }; |
169 | |
170 | } // namespace rpc |
171 | } // namespace distributed |
172 | } // namespace torch |
173 | |