1 | #include <cstdio> |
2 | #include <string> |
3 | #include <array> |
4 | |
5 | #include <gtest/gtest.h> |
6 | |
7 | #include "caffe2/serialize/inline_container.h" |
8 | #include "c10/util/irange.h" |
9 | |
10 | namespace caffe2 { |
11 | namespace serialize { |
12 | namespace { |
13 | |
14 | TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { |
15 | int64_t kFieldAlignment = 64L; |
16 | |
17 | std::ostringstream oss; |
18 | // write records through writers |
19 | PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { |
20 | oss.write(static_cast<const char*>(b), n); |
21 | return oss ? n : 0; |
22 | }); |
23 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
24 | std::array<char, 127> data1; |
25 | |
26 | for (auto i: c10::irange( data1.size())) { |
27 | data1[i] = data1.size() - i; |
28 | } |
29 | writer.writeRecord("key1" , data1.data(), data1.size()); |
30 | |
31 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
32 | std::array<char, 64> data2; |
33 | for (auto i: c10::irange(data2.size())) { |
34 | data2[i] = data2.size() - i; |
35 | } |
36 | writer.writeRecord("key2" , data2.data(), data2.size()); |
37 | |
38 | const std::unordered_set<std::string>& written_records = |
39 | writer.getAllWrittenRecords(); |
40 | ASSERT_EQ(written_records.size(), 2); |
41 | ASSERT_EQ(written_records.count("key1" ), 1); |
42 | ASSERT_EQ(written_records.count("key2" ), 1); |
43 | |
44 | writer.writeEndOfFile(); |
45 | |
46 | std::string the_file = oss.str(); |
47 | std::ofstream foo("output.zip" ); |
48 | foo.write(the_file.c_str(), the_file.size()); |
49 | foo.close(); |
50 | |
51 | std::istringstream iss(the_file); |
52 | |
53 | // read records through readers |
54 | PyTorchStreamReader reader(&iss); |
55 | ASSERT_TRUE(reader.hasRecord("key1" )); |
56 | ASSERT_TRUE(reader.hasRecord("key2" )); |
57 | ASSERT_FALSE(reader.hasRecord("key2000" )); |
58 | at::DataPtr data_ptr; |
59 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
60 | int64_t size; |
61 | std::tie(data_ptr, size) = reader.getRecord("key1" ); |
62 | size_t off1 = reader.getRecordOffset("key1" ); |
63 | ASSERT_EQ(size, data1.size()); |
64 | ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0); |
65 | ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0); |
66 | ASSERT_EQ(off1 % kFieldAlignment, 0); |
67 | |
68 | std::tie(data_ptr, size) = reader.getRecord("key2" ); |
69 | size_t off2 = reader.getRecordOffset("key2" ); |
70 | ASSERT_EQ(off2 % kFieldAlignment, 0); |
71 | |
72 | ASSERT_EQ(size, data2.size()); |
73 | ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0); |
74 | ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0); |
75 | } |
76 | |
77 | TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) { |
78 | std::ostringstream oss; |
79 | // write records through writers |
80 | PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { |
81 | oss.write(static_cast<const char*>(b), n); |
82 | return oss ? n : 0; |
83 | }); |
84 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
85 | std::array<char, 127> data1; |
86 | |
87 | for (auto i: c10::irange(data1.size())) { |
88 | data1[i] = data1.size() - i; |
89 | } |
90 | writer.writeRecord("key1" , data1.data(), data1.size()); |
91 | |
92 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
93 | std::array<char, 64> data2; |
94 | for (auto i: c10::irange(data2.size())) { |
95 | data2[i] = data2.size() - i; |
96 | } |
97 | writer.writeRecord("key2" , data2.data(), data2.size()); |
98 | |
99 | const std::unordered_set<std::string>& written_records = |
100 | writer.getAllWrittenRecords(); |
101 | ASSERT_EQ(written_records.size(), 2); |
102 | ASSERT_EQ(written_records.count("key1" ), 1); |
103 | ASSERT_EQ(written_records.count("key2" ), 1); |
104 | |
105 | writer.writeEndOfFile(); |
106 | |
107 | std::string the_file = oss.str(); |
108 | std::ofstream foo("output2.zip" ); |
109 | foo.write(the_file.c_str(), the_file.size()); |
110 | foo.close(); |
111 | |
112 | std::istringstream iss(the_file); |
113 | |
114 | // read records through readers |
115 | PyTorchStreamReader reader(&iss); |
116 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
117 | EXPECT_THROW(reader.getRecord("key3" ), c10::Error); |
118 | |
119 | // Reader should still work after throwing |
120 | EXPECT_TRUE(reader.hasRecord("key1" )); |
121 | } |
122 | |
123 | TEST(PytorchStreamWriterAndReader, SkipDebugRecords) { |
124 | std::ostringstream oss; |
125 | PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { |
126 | oss.write(static_cast<const char*>(b), n); |
127 | return oss ? n : 0; |
128 | }); |
129 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
130 | std::array<char, 127> data1; |
131 | |
132 | for (auto i: c10::irange(data1.size())) { |
133 | data1[i] = data1.size() - i; |
134 | } |
135 | writer.writeRecord("key1.debug_pkl" , data1.data(), data1.size()); |
136 | |
137 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers) |
138 | std::array<char, 64> data2; |
139 | for (auto i: c10::irange(data2.size())) { |
140 | data2[i] = data2.size() - i; |
141 | } |
142 | writer.writeRecord("key2.debug_pkl" , data2.data(), data2.size()); |
143 | |
144 | const std::unordered_set<std::string>& written_records = |
145 | writer.getAllWrittenRecords(); |
146 | ASSERT_EQ(written_records.size(), 2); |
147 | ASSERT_EQ(written_records.count("key1.debug_pkl" ), 1); |
148 | ASSERT_EQ(written_records.count("key2.debug_pkl" ), 1); |
149 | writer.writeEndOfFile(); |
150 | |
151 | std::string the_file = oss.str(); |
152 | std::ofstream foo("output2.zip" ); |
153 | foo.write(the_file.c_str(), the_file.size()); |
154 | foo.close(); |
155 | |
156 | std::istringstream iss(the_file); |
157 | |
158 | // read records through readers |
159 | PyTorchStreamReader reader(&iss); |
160 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
161 | |
162 | reader.setShouldLoadDebugSymbol(false); |
163 | EXPECT_FALSE(reader.hasRecord("key1.debug_pkl" )); |
164 | at::DataPtr ptr; |
165 | size_t size; |
166 | std::tie(ptr, size) = reader.getRecord("key1.debug_pkl" ); |
167 | EXPECT_EQ(size, 0); |
168 | } |
169 | |
170 | } // namespace |
171 | } // namespace serialize |
172 | } // namespace caffe2 |
173 | |