1 | #include <memory> |
2 | |
3 | #include <gtest/gtest.h> |
4 | |
5 | #include <ATen/ATen.h> |
6 | #include <torch/csrc/distributed/autograd/context/container.h> |
7 | #include <torch/csrc/distributed/autograd/context/context.h> |
8 | #include <torch/csrc/distributed/autograd/engine/dist_engine.h> |
9 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> |
10 | #include <torch/csrc/distributed/autograd/utils.h> |
11 | #include <torch/torch.h> |
12 | |
13 | namespace torch { |
14 | namespace distributed { |
15 | namespace autograd { |
16 | |
// Test fixture for the distributed autograd unit tests. A single
// DistAutogradContainer is shared across every test in the suite; each test
// creates its own autograd context via newContext(), which TearDown()
// releases so state does not leak between tests.
class DistAutogradTest : public ::testing::Test {
 protected:
  // Runs once for the whole suite: initialize the container for worker id 0.
  static void SetUpTestCase() {
    autogradContainer_ = &DistAutogradContainer::init(0);
  }

  // Runs after each test: release the context the test created, keyed by the
  // id of the container's current context.
  void TearDown() override {
    autogradContainer_->releaseContext(
        autogradContainer_->currentContext()->contextId());
  }

  // Shared container; set once in SetUpTestCase().
  static DistAutogradContainer* autogradContainer_;
};

// Out-of-class definition of the fixture's static member.
DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;
32 | |
33 | TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) { |
34 | auto options = at::TensorOptions().requires_grad(true); |
35 | auto in1 = torch::ones({3, 3}, options); |
36 | auto in2 = torch::ones({3, 3}, options); |
37 | |
38 | autogradContainer_->newContext(); |
39 | auto autogradContext = autogradContainer_->currentContext(); |
40 | // Attach the send autograd function to tensors. |
41 | std::vector<torch::Tensor> tensors = {in1, in2}; |
42 | rpc::worker_id_t worker_id = 1; |
43 | addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors); |
44 | autogradContext->addKnownWorkerId(worker_id); |
45 | auto send_function = autogradContext->sendFunctions()[1]; |
46 | |
47 | // ensure that the worker_ids are recorded |
48 | auto knownWorkerIds = autogradContext->getKnownWorkerIds(); |
49 | ASSERT_TRUE(knownWorkerIds.find(worker_id) != knownWorkerIds.end()); |
50 | ASSERT_EQ(knownWorkerIds.size(), 1); |
51 | |
52 | // This should fail since the SendRpcBackward function shouldn't receive any |
53 | // inputs grad. |
54 | EXPECT_THROW(send_function->apply({in1, in2}), c10::Error); |
55 | |
56 | // This should fail since the SendRpcBackward function encounters an undefined |
57 | // grad. |
58 | send_function->setGrads({in1, torch::autograd::Variable()}); |
59 | EXPECT_THROW(send_function->apply({}), c10::Error); |
60 | } |
61 | |
62 | TEST_F(DistAutogradTest, TestInitializedContextCleanup) { |
63 | autogradContainer_->newContext(); |
64 | auto contextId = autogradContainer_->currentContext()->contextId(); |
65 | auto& engine = DistEngine::getInstance(); |
66 | ASSERT_EQ(0, engine.numBackwardPasses()); |
67 | |
68 | // Build autograd graph |
69 | auto x = torch::randn({2, 2}, torch::requires_grad()); |
70 | auto y = torch::randn({2, 2}, torch::requires_grad()); |
71 | auto z = (x * x + y * y).sum(); |
72 | ASSERT_NE(nullptr, z.grad_fn()); |
73 | |
74 | // Execute engine. |
75 | engine.execute(contextId, {z}, /* retainGraph */ false); |
76 | |
77 | // Validate appropriate cleanup. |
78 | ASSERT_EQ(0, engine.numBackwardPasses()); |
79 | } |
80 | |
81 | TEST_F(DistAutogradTest, TestInitializedContextCleanupSendFunction) { |
82 | autogradContainer_->newContext(); |
83 | auto context = autogradContainer_->currentContext(); |
84 | auto& engine = DistEngine::getInstance(); |
85 | ASSERT_EQ(0, engine.numBackwardPasses()); |
86 | |
87 | // Attach send function. |
88 | auto options = at::TensorOptions().requires_grad(true); |
89 | auto t = torch::ones({1}, options); |
90 | auto tensors = std::vector<torch::Tensor>{t}; |
91 | addSendRpcBackward( |
92 | context, AutogradMetadata(context->contextId(), 0), tensors); |
93 | |
94 | auto sendFunction = context->retrieveSendFunction(0); |
95 | sendFunction->setGrads({t}); |
96 | |
97 | // Execute engine. |
98 | engine |
99 | .executeSendFunctionAsync(context, sendFunction, /*retrainGraph*/ false) |
100 | ->wait(); |
101 | |
102 | // Validate appropriate cleanup. |
103 | ASSERT_EQ(0, engine.numBackwardPasses()); |
104 | } |
105 | |
106 | } // namespace autograd |
107 | } // namespace distributed |
108 | } // namespace torch |
109 | |