#include <array>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
#include "c10/util/irange.h"
9
10namespace caffe2 {
11namespace serialize {
12namespace {
13
14TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
15 int64_t kFieldAlignment = 64L;
16
17 std::ostringstream oss;
18 // write records through writers
19 PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
20 oss.write(static_cast<const char*>(b), n);
21 return oss ? n : 0;
22 });
23 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
24 std::array<char, 127> data1;
25
26 for (auto i: c10::irange( data1.size())) {
27 data1[i] = data1.size() - i;
28 }
29 writer.writeRecord("key1", data1.data(), data1.size());
30
31 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
32 std::array<char, 64> data2;
33 for (auto i: c10::irange(data2.size())) {
34 data2[i] = data2.size() - i;
35 }
36 writer.writeRecord("key2", data2.data(), data2.size());
37
38 const std::unordered_set<std::string>& written_records =
39 writer.getAllWrittenRecords();
40 ASSERT_EQ(written_records.size(), 2);
41 ASSERT_EQ(written_records.count("key1"), 1);
42 ASSERT_EQ(written_records.count("key2"), 1);
43
44 writer.writeEndOfFile();
45
46 std::string the_file = oss.str();
47 std::ofstream foo("output.zip");
48 foo.write(the_file.c_str(), the_file.size());
49 foo.close();
50
51 std::istringstream iss(the_file);
52
53 // read records through readers
54 PyTorchStreamReader reader(&iss);
55 ASSERT_TRUE(reader.hasRecord("key1"));
56 ASSERT_TRUE(reader.hasRecord("key2"));
57 ASSERT_FALSE(reader.hasRecord("key2000"));
58 at::DataPtr data_ptr;
59 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
60 int64_t size;
61 std::tie(data_ptr, size) = reader.getRecord("key1");
62 size_t off1 = reader.getRecordOffset("key1");
63 ASSERT_EQ(size, data1.size());
64 ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
65 ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
66 ASSERT_EQ(off1 % kFieldAlignment, 0);
67
68 std::tie(data_ptr, size) = reader.getRecord("key2");
69 size_t off2 = reader.getRecordOffset("key2");
70 ASSERT_EQ(off2 % kFieldAlignment, 0);
71
72 ASSERT_EQ(size, data2.size());
73 ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
74 ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
75}
76
77TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
78 std::ostringstream oss;
79 // write records through writers
80 PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
81 oss.write(static_cast<const char*>(b), n);
82 return oss ? n : 0;
83 });
84 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
85 std::array<char, 127> data1;
86
87 for (auto i: c10::irange(data1.size())) {
88 data1[i] = data1.size() - i;
89 }
90 writer.writeRecord("key1", data1.data(), data1.size());
91
92 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
93 std::array<char, 64> data2;
94 for (auto i: c10::irange(data2.size())) {
95 data2[i] = data2.size() - i;
96 }
97 writer.writeRecord("key2", data2.data(), data2.size());
98
99 const std::unordered_set<std::string>& written_records =
100 writer.getAllWrittenRecords();
101 ASSERT_EQ(written_records.size(), 2);
102 ASSERT_EQ(written_records.count("key1"), 1);
103 ASSERT_EQ(written_records.count("key2"), 1);
104
105 writer.writeEndOfFile();
106
107 std::string the_file = oss.str();
108 std::ofstream foo("output2.zip");
109 foo.write(the_file.c_str(), the_file.size());
110 foo.close();
111
112 std::istringstream iss(the_file);
113
114 // read records through readers
115 PyTorchStreamReader reader(&iss);
116 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
117 EXPECT_THROW(reader.getRecord("key3"), c10::Error);
118
119 // Reader should still work after throwing
120 EXPECT_TRUE(reader.hasRecord("key1"));
121}
122
123TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
124 std::ostringstream oss;
125 PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
126 oss.write(static_cast<const char*>(b), n);
127 return oss ? n : 0;
128 });
129 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
130 std::array<char, 127> data1;
131
132 for (auto i: c10::irange(data1.size())) {
133 data1[i] = data1.size() - i;
134 }
135 writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
136
137 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
138 std::array<char, 64> data2;
139 for (auto i: c10::irange(data2.size())) {
140 data2[i] = data2.size() - i;
141 }
142 writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());
143
144 const std::unordered_set<std::string>& written_records =
145 writer.getAllWrittenRecords();
146 ASSERT_EQ(written_records.size(), 2);
147 ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
148 ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
149 writer.writeEndOfFile();
150
151 std::string the_file = oss.str();
152 std::ofstream foo("output2.zip");
153 foo.write(the_file.c_str(), the_file.size());
154 foo.close();
155
156 std::istringstream iss(the_file);
157
158 // read records through readers
159 PyTorchStreamReader reader(&iss);
160 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
161
162 reader.setShouldLoadDebugSymbol(false);
163 EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
164 at::DataPtr ptr;
165 size_t size;
166 std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
167 EXPECT_EQ(size, 0);
168}
169
170} // namespace
171} // namespace serialize
172} // namespace caffe2
173