#include <torch/csrc/jit/mobile/file_format.h>

#include <gtest/gtest.h>

#include <ios>
#include <sstream>

7 | // Tests go in torch::jit |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | TEST(FileFormatTest, IdentifiesFlatbufferStream) { |
12 | // Create data whose initial bytes look like a Flatbuffer stream. |
13 | std::stringstream data; |
14 | data << "abcd" // First four bytes don't matter. |
15 | << "PTMF" // Magic string. |
16 | << "efgh" ; // Trailing bytes don't matter. |
17 | |
18 | // The data should be identified as Flatbuffer. |
19 | EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat); |
20 | } |
21 | |
22 | TEST(FileFormatTest, IdentifiesZipStream) { |
23 | // Create data whose initial bytes look like a ZIP stream. |
24 | std::stringstream data; |
25 | data << "PK\x03\x04" // Magic string. |
26 | << "abcd" // Trailing bytes don't matter. |
27 | << "efgh" ; |
28 | |
29 | // The data should be identified as ZIP. |
30 | EXPECT_EQ(getFileFormat(data), FileFormat::ZipFileFormat); |
31 | } |
32 | |
33 | TEST(FileFormatTest, FlatbufferTakesPrecedence) { |
34 | // Since the Flatbuffer and ZIP magic bytes are at different offsets, |
35 | // the same data could be identified as both. Demonstrate that Flatbuffer |
36 | // takes precedence. (See details in file_format.h) |
37 | std::stringstream data; |
38 | data << "PK\x03\x04" // ZIP magic string. |
39 | << "PTMF" // Flatbuffer magic string. |
40 | << "abcd" ; // Trailing bytes don't matter. |
41 | |
42 | // The data should be identified as Flatbuffer. |
43 | EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat); |
44 | } |
45 | |
46 | TEST(FileFormatTest, HandlesUnknownStream) { |
47 | // Create data that doesn't look like any known format. |
48 | std::stringstream data; |
49 | data << "abcd" |
50 | << "efgh" |
51 | << "ijkl" ; |
52 | |
53 | // The data should be classified as unknown. |
54 | EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat); |
55 | } |
56 | |
57 | TEST(FileFormatTest, ShortStreamIsUnknown) { |
58 | // Create data with fewer than kFileFormatHeaderSize (8) bytes. |
59 | std::stringstream data; |
60 | data << "ABCD" ; |
61 | |
62 | // The data should be classified as unknown. |
63 | EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat); |
64 | } |
65 | |
66 | TEST(FileFormatTest, EmptyStreamIsUnknown) { |
67 | // Create an empty stream. |
68 | std::stringstream data; |
69 | |
70 | // The data should be classified as unknown. |
71 | EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat); |
72 | } |
73 | |
74 | TEST(FileFormatTest, BadStreamIsUnknown) { |
75 | // Create a stream with valid Flatbuffer data. |
76 | std::stringstream data; |
77 | data << "abcd" |
78 | << "PTMF" // Flatbuffer magic string. |
79 | << "efgh" ; |
80 | |
81 | // Demonstrate that the data would normally be identified as Flatbuffer. |
82 | EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat); |
83 | |
84 | // Mark the stream as bad, and demonstrate that it is in an error state. |
85 | data.setstate(std::stringstream::badbit); |
86 | // Demonstrate that the stream is in an error state. |
87 | EXPECT_FALSE(data.good()); |
88 | |
89 | // The data should now be classified as unknown. |
90 | EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat); |
91 | } |
92 | |
93 | TEST(FileFormatTest, StreamOffsetIsObservedAndRestored) { |
94 | // Create data with a Flatbuffer header at a non-zero offset into the stream. |
95 | std::stringstream data; |
96 | // Add initial padding. |
97 | data << "PADDING" ; |
98 | size_t offset = data.str().size(); |
99 | // Add a valid Flatbuffer header. |
100 | data << "abcd" |
101 | << "PTMF" // Flatbuffer magic string. |
102 | << "efgh" ; |
103 | // Seek just after the padding. |
104 | data.seekg(static_cast<std::stringstream::off_type>(offset), data.beg); |
105 | // Demonstrate that the stream points to the beginning of the Flatbuffer data, |
106 | // not to the padding. |
107 | EXPECT_EQ(data.peek(), 'a'); |
108 | |
109 | // The data should be identified as Flatbuffer. |
110 | EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat); |
111 | |
112 | // The stream position should be where it was before identification. |
113 | EXPECT_EQ(offset, data.tellg()); |
114 | } |
115 | |
116 | TEST(FileFormatTest, HandlesMissingFile) { |
117 | // A missing file should be classified as unknown. |
118 | EXPECT_EQ( |
119 | getFileFormat("NON_EXISTENT_FILE_4965c363-44a7-443c-983a-8895eead0277" ), |
120 | FileFormat::UnknownFileFormat); |
121 | } |
122 | |
123 | } // namespace jit |
124 | } // namespace torch |
125 | |