1 | #include <gtest/gtest.h> |
2 | |
3 | #include "e2e_test_base.h" |
4 | |
5 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
6 | #include <torch/csrc/distributed/rpc/request_callback_no_python.h> |
7 | #include <torch/csrc/distributed/rpc/tensorpipe_agent.h> |
8 | #include <torch/torch.h> |
9 | |
10 | namespace torch { |
11 | namespace distributed { |
12 | namespace rpc { |
13 | |
14 | #ifdef USE_TENSORPIPE |
15 | |
16 | class TestE2ETensorPipe : public TestE2EBase { |
17 | protected: |
18 | void buildRpcAgent() override { |
19 | auto options = c10d::ProcessGroupGloo::Options::create(); |
20 | options->devices.push_back( |
21 | ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress)); |
22 | float rpcTimeout = 30; |
23 | |
24 | TensorPipeRpcBackendOptions opts( |
25 | /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()), |
26 | /*transports=*/nullopt, |
27 | /*channels=*/nullopt, |
28 | /*rpc_timeout=*/rpcTimeout, |
29 | /*init_method=*/"unused" ); |
30 | |
31 | rpcAgent = std::make_shared<TensorPipeAgent>( |
32 | store, |
33 | "worker" , |
34 | 0, |
35 | numWorkers, |
36 | opts, |
37 | std::unordered_map<std::string, DeviceMap>{}, |
38 | std::vector<c10::Device>{}, |
39 | std::make_unique<RequestCallbackNoPython>()); |
40 | } |
41 | }; |
42 | |
43 | // End to end training loop test in C++ so that we can run LSAN on this test to |
44 | // catch memory leaks. Enabling LSAN with python multiprocessing has been |
45 | // challenging and we don't have a good solution yet. |
46 | TEST_F(TestE2ETensorPipe, TestTrainingLoop) { |
47 | runTrainingLoop(); |
48 | // Ensure the tensorpipe internal state is cleared up. |
49 | auto tensorpipeAgent = std::static_pointer_cast<TensorPipeAgent>(rpcAgent); |
50 | |
51 | // Shutdown RPC agent for all RPCs to clean up. |
52 | tensorpipeAgent->join(); |
53 | tensorpipeAgent->shutdown(); |
54 | ASSERT_EQ(0, tensorpipeAgent->numPendingResponses()); |
55 | ASSERT_EQ(0, tensorpipeAgent->timeoutMapSize()); |
56 | ASSERT_EQ(0, tensorpipeAgent->messageIdToTimeoutMapSize()); |
57 | } |
58 | |
59 | #endif |
60 | |
61 | } // namespace rpc |
62 | } // namespace distributed |
63 | } // namespace torch |
64 | |