#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/torch.h>

#include <cstring>
#include <memory>
#include <string>
#include <vector>
10
11using ::testing::IsSubstring;
12
13TEST(WireSerialize, Base) {
14 auto run = [](const std::string& payload,
15 const std::vector<at::Tensor>& tensors) {
16 std::string serialized;
17 {
18 std::vector<char> mpayload(payload.begin(), payload.end());
19 std::vector<at::Tensor> mtensors = tensors;
20 serialized = torch::distributed::rpc::wireSerialize(
21 std::move(mpayload), std::move(mtensors));
22 }
23 auto deser = torch::distributed::rpc::wireDeserialize(
24 serialized.data(), serialized.size());
25 EXPECT_EQ(payload.size(), deser.first.size());
26 EXPECT_EQ(tensors.size(), deser.second.size());
27 if (payload.size() > 0) {
28 EXPECT_TRUE(
29 memcmp(deser.first.data(), payload.data(), payload.size()) == 0);
30 }
31 for (const auto i : c10::irange(tensors.size())) {
32 EXPECT_TRUE(torch::equal(tensors[i], deser.second[i]));
33 }
34 };
35 run("", {});
36 run("hi", {});
37 run("", {torch::randn({5, 5})});
38 run("hi", {torch::randn({5, 5})});
39 run("more", {torch::randn({5, 5}), torch::rand({10, 10})});
40}
41
42TEST(WireSerialize, RecopySparseTensors) {
43 // Take a 1K row of a 1M tensors, and make sure we don't send across 1M rows.
44 constexpr size_t k1K = 1024;
45 at::Tensor main = torch::randn({k1K, k1K});
46 at::Tensor tiny = main.select(0, 2); // Select a row in the middle
47 EXPECT_EQ(tiny.numel(), k1K);
48 EXPECT_EQ(tiny.storage().nbytes() / tiny.dtype().itemsize(), k1K * k1K);
49 auto ser = torch::distributed::rpc::wireSerialize({}, {tiny});
50 auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
51 EXPECT_TRUE(torch::equal(tiny, deser.second[0]));
52 EXPECT_LT(ser.size(), (tiny.element_size() * k1K) + k1K);
53}
54
TEST(WireSerialize, CloneSparseTensors) {
  // Exercises cloneSparseTensors() on three cases: a whole dense tensor
  // (expected to pass through), a small view of a much larger storage
  // (expected to be cloned), and a sparse tensor (expected to pass through).
  constexpr size_t k1K = 1024;
  at::Tensor big = torch::randn({k1K, k1K});
  auto v1 = torch::distributed::rpc::cloneSparseTensors({big});
  EXPECT_EQ(v1.get(0).storage(), big.storage()); // Not cloned

  at::Tensor tiny = big.select(0, 2); // Select a row in the middle
  auto v2 = torch::distributed::rpc::cloneSparseTensors({tiny});
  // Address comparison: the clone must own a distinct Storage object.
  EXPECT_NE(&v2.get(0).storage(), &tiny.storage()); // Cloned.
  EXPECT_TRUE(torch::equal(v2.get(0), tiny));

  at::Tensor sparse = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
  auto v3 = torch::distributed::rpc::cloneSparseTensors({sparse});
  // There is no storage() to compare for sparse layout; is_same() confirms
  // the very same tensor object was passed through unchanged.
  EXPECT_TRUE(v3.get(0).is_same(sparse));
}
71
72TEST(WireSerialize, Errors) {
73 auto checkMessage = [](auto&& f, const char* msg) {
74 try {
75 f();
76 FAIL();
77 } catch (const std::exception& e) {
78 EXPECT_PRED_FORMAT2(IsSubstring, msg, e.what());
79 } catch (...) {
80 FAIL();
81 }
82 };
83 checkMessage(
84 []() { (void)torch::distributed::rpc::wireDeserialize("", 0); },
85 "failed parse");
86 checkMessage(
87 []() { (void)torch::distributed::rpc::wireDeserialize(" ", 1); },
88 "failed parse");
89 auto serialized =
90 torch::distributed::rpc::wireSerialize({}, {torch::randn({5, 5})});
91 checkMessage(
92 [&]() {
93 (void)torch::distributed::rpc::wireDeserialize(
94 serialized.data(), serialized.size() / 2);
95 },
96 "failed bounds");
97}
98
99// Enable this once JIT Pickler supports sparse tensors.
100TEST(WireSerialize, DISABLED_Sparse) {
101 at::Tensor main = at::empty({2, 3}, at::dtype<float>().layout(at::kSparse));
102 auto ser = torch::distributed::rpc::wireSerialize({}, {main.to(at::kSparse)});
103 auto deser = torch::distributed::rpc::wireDeserialize(ser.data(), ser.size());
104 EXPECT_TRUE(torch::equal(main, deser.second[0]));
105}
106