#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/torch.h>

#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using ::testing::IsSubstring;

13 | TEST(WireSerialize, Base) { |
14 | auto run = [](const std::string& payload, |
15 | const std::vector<at::Tensor>& tensors) { |
16 | std::string serialized; |
17 | { |
18 | std::vector<char> mpayload(payload.begin(), payload.end()); |
19 | std::vector<at::Tensor> mtensors = tensors; |
20 | serialized = torch::distributed::rpc::wireSerialize( |
21 | std::move(mpayload), std::move(mtensors)); |
22 | } |
23 | auto deser = torch::distributed::rpc::wireDeserialize( |
24 | serialized.data(), serialized.size()); |
25 | EXPECT_EQ(payload.size(), deser.first.size()); |
26 | EXPECT_EQ(tensors.size(), deser.second.size()); |
27 | if (payload.size() > 0) { |
28 | EXPECT_TRUE( |
29 | memcmp(deser.first.data(), payload.data(), payload.size()) == 0); |
30 | } |
31 | for (const auto i : c10::irange(tensors.size())) { |
32 | EXPECT_TRUE(torch::equal(tensors[i], deser.second[i])); |
33 | } |
34 | }; |
35 | run("" , {}); |
36 | run("hi" , {}); |
37 | run("" , {torch::randn({5, 5})}); |
38 | run("hi" , {torch::randn({5, 5})}); |
39 | run("more" , {torch::randn({5, 5}), torch::rand({10, 10})}); |
40 | } |
41 | |
42 | TEST(WireSerialize, RecopySparseTensors) { |
43 | // Take a 1K row of a 1M tensors, and make sure we don't send across 1M rows. |
44 | constexpr size_t k1K = 1024; |
45 | at::Tensor main = torch::randn({k1K, k1K}); |
46 | at::Tensor tiny = main.select(0, 2); // Select a row in the middle |
47 | EXPECT_EQ(tiny.numel(), k1K); |
48 | EXPECT_EQ(tiny.storage().nbytes() / tiny.dtype().itemsize(), k1K * k1K); |
49 | auto ser = torch::distributed::rpc::wireSerialize({}, {tiny}); |
50 | auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size()); |
51 | EXPECT_TRUE(torch::equal(tiny, deser.second[0])); |
52 | EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K); |
53 | } |
54 | |
55 | TEST(WireSerialize, CloneSparseTensors) { |
56 | constexpr size_t k1K = 1024; |
57 | at::Tensor big = torch::randn({k1K, k1K}); |
58 | auto v1 = torch::distributed::rpc::cloneSparseTensors({big}); |
59 | EXPECT_EQ(v1.get(0).storage(), big.storage()); // Not cloned |
60 | |
61 | at::Tensor tiny = big.select(0, 2); // Select a row in the middle |
62 | auto v2 = torch::distributed::rpc::cloneSparseTensors({tiny}); |
63 | EXPECT_NE(&v2.get(0).storage(), &tiny.storage()); // Cloned. |
64 | EXPECT_TRUE(torch::equal(v2.get(0), tiny)); |
65 | |
66 | at::Tensor sparse = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse)); |
67 | auto v3 = torch::distributed::rpc::cloneSparseTensors({sparse}); |
68 | // There is no storage() to compare, but at least confirm equality. |
69 | EXPECT_TRUE(v3.get(0).is_same(sparse)); |
70 | } |
71 | |
72 | TEST(WireSerialize, Errors) { |
73 | auto checkMessage = [](auto&& f, const char* msg) { |
74 | try { |
75 | f(); |
76 | FAIL(); |
77 | } catch (const std::exception& e) { |
78 | EXPECT_PRED_FORMAT2(IsSubstring, msg, e.what()); |
79 | } catch (...) { |
80 | FAIL(); |
81 | } |
82 | }; |
83 | checkMessage( |
84 | []() { (void)torch::distributed::rpc::wireDeserialize("" , 0); }, |
85 | "failed parse" ); |
86 | checkMessage( |
87 | []() { (void)torch::distributed::rpc::wireDeserialize(" " , 1); }, |
88 | "failed parse" ); |
89 | auto serialized = |
90 | torch::distributed::rpc::wireSerialize({}, {torch::randn({5, 5})}); |
91 | checkMessage( |
92 | [&]() { |
93 | (void)torch::distributed::rpc::wireDeserialize( |
94 | serialized.data(), serialized.size() / 2); |
95 | }, |
96 | "failed bounds" ); |
97 | } |
98 | |
99 | // Enable this once JIT Pickler supports sparse tensors. |
100 | TEST(WireSerialize, DISABLED_Sparse) { |
101 | at::Tensor main = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse)); |
102 | auto ser = torch::distributed::rpc::wireSerialize({}, {main.to(at::kSparse)}); |
103 | auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size()); |
104 | EXPECT_TRUE(torch::equal(main, deser.second[0])); |
105 | } |
106 | |