1#include <gtest/gtest.h>
2
3#include "e2e_test_base.h"
4
5#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
6#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
7#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
8#include <torch/torch.h>
9
10namespace torch {
11namespace distributed {
12namespace rpc {
13
14#ifdef USE_TENSORPIPE
15
16class TestE2ETensorPipe : public TestE2EBase {
17 protected:
18 void buildRpcAgent() override {
19 auto options = c10d::ProcessGroupGloo::Options::create();
20 options->devices.push_back(
21 ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
22 float rpcTimeout = 30;
23
24 TensorPipeRpcBackendOptions opts(
25 /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()),
26 /*transports=*/nullopt,
27 /*channels=*/nullopt,
28 /*rpc_timeout=*/rpcTimeout,
29 /*init_method=*/"unused");
30
31 rpcAgent = std::make_shared<TensorPipeAgent>(
32 store,
33 "worker",
34 0,
35 numWorkers,
36 opts,
37 std::unordered_map<std::string, DeviceMap>{},
38 std::vector<c10::Device>{},
39 std::make_unique<RequestCallbackNoPython>());
40 }
41};
42
43// End to end training loop test in C++ so that we can run LSAN on this test to
44// catch memory leaks. Enabling LSAN with python multiprocessing has been
45// challenging and we don't have a good solution yet.
46TEST_F(TestE2ETensorPipe, TestTrainingLoop) {
47 runTrainingLoop();
48 // Ensure the tensorpipe internal state is cleared up.
49 auto tensorpipeAgent = std::static_pointer_cast<TensorPipeAgent>(rpcAgent);
50
51 // Shutdown RPC agent for all RPCs to clean up.
52 tensorpipeAgent->join();
53 tensorpipeAgent->shutdown();
54 ASSERT_EQ(0, tensorpipeAgent->numPendingResponses());
55 ASSERT_EQ(0, tensorpipeAgent->timeoutMapSize());
56 ASSERT_EQ(0, tensorpipeAgent->messageIdToTimeoutMapSize());
57}
58
59#endif
60
61} // namespace rpc
62} // namespace distributed
63} // namespace torch
64