#include <memory>

#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/torch.h>

namespace torch {
namespace distributed {
namespace autograd {

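// Test fixture that initializes a single DistAutogradContainer (registered as
// worker 0) once for the whole suite. TearDown releases the context each test
// creates so no autograd state leaks between test cases.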
class DistAutogradTest : public ::testing::Test {
 protected:
  static void SetUpTestCase() {
    autogradContainer_ = &DistAutogradContainer::init(0);
  }

  void TearDown() override {
    autogradContainer_->releaseContext(
        autogradContainer_->currentContext()->contextId());
  }

  static DistAutogradContainer* autogradContainer_;
};

DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;

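// Validates SendRpcBackward's input handling: the function must be invoked
// with no input gradients (they are supplied ahead of time via setGrads()),
// and every gradient it holds must be a defined tensor.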
TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
  auto options = at::TensorOptions().requires_grad(true);
  auto in1 = torch::ones({3, 3}, options);
  auto in2 = torch::ones({3, 3}, options);

  autogradContainer_->newContext();
  auto autogradContext = autogradContainer_->currentContext();
  // Attach the send autograd function to tensors.
  std::vector<torch::Tensor> tensors = {in1, in2};
  rpc::worker_id_t worker_id = 1;
  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
  autogradContext->addKnownWorkerId(worker_id);
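  // Send functions are keyed by their autograd message id; index 1 matches
  // the message id passed in the AutogradMetadata above.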
  auto send_function = autogradContext->sendFunctions()[1];

  // Ensure that the worker_id was recorded.
  auto knownWorkerIds = autogradContext->getKnownWorkerIds();
  ASSERT_TRUE(knownWorkerIds.find(worker_id) != knownWorkerIds.end());
  ASSERT_EQ(knownWorkerIds.size(), 1);

  // This should fail since the SendRpcBackward function shouldn't receive any
  // input gradients.
  EXPECT_THROW(send_function->apply({in1, in2}), c10::Error);

  // This should fail since the SendRpcBackward function encounters an
  // undefined grad.
  send_function->setGrads({in1, torch::autograd::Variable()});
  EXPECT_THROW(send_function->apply({}), c10::Error);
}

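// Verifies that DistEngine tears down its per-context state once a full
// backward pass over the local graph completes.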
TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
  autogradContainer_->newContext();
  auto contextId = autogradContainer_->currentContext()->contextId();
  auto& engine = DistEngine::getInstance();
  ASSERT_EQ(0, engine.numBackwardPasses());

  // Build autograd graph: z = sum(x^2 + y^2), so backward produces
  // dz/dx = 2x and dz/dy = 2y.
  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto y = torch::randn({2, 2}, torch::requires_grad());
  auto z = (x * x + y * y).sum();
  ASSERT_NE(nullptr, z.grad_fn());

  // Execute engine.
  engine.execute(contextId, {z}, /* retainGraph */ false);

  // Validate appropriate cleanup.
  ASSERT_EQ(0, engine.numBackwardPasses());
}

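// Same cleanup check as above, but drives the engine through the asynchronous
// send-function path instead of a full backward pass.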
TEST_F(DistAutogradTest, TestInitializedContextCleanupSendFunction) {
  autogradContainer_->newContext();
  auto context = autogradContainer_->currentContext();
  auto& engine = DistEngine::getInstance();
  ASSERT_EQ(0, engine.numBackwardPasses());

  // Attach send function.
  auto options = at::TensorOptions().requires_grad(true);
  auto t = torch::ones({1}, options);
  auto tensors = std::vector<torch::Tensor>{t};
  addSendRpcBackward(
      context, AutogradMetadata(context->contextId(), 0), tensors);

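  // Supply the gradient up front: SendRpcBackward consumes gradients set via
  // setGrads() rather than taking them as inputs to apply().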
  auto sendFunction = context->retrieveSendFunction(0);
  sendFunction->setGrads({t});

  // Execute engine and block on the returned future so the backward pass
  // finishes before we validate cleanup.
  engine
      .executeSendFunctionAsync(context, sendFunction, /*retainGraph*/ false)
      ->wait();

  // Validate appropriate cleanup.
  ASSERT_EQ(0, engine.numBackwardPasses());
}

} // namespace autograd
} // namespace distributed
} // namespace torch