1#include <torch/csrc/jit/mobile/file_format.h>
2
3#include <gtest/gtest.h>
4
5#include <sstream>
6
7// Tests go in torch::jit
8namespace torch {
9namespace jit {
10
11TEST(FileFormatTest, IdentifiesFlatbufferStream) {
12 // Create data whose initial bytes look like a Flatbuffer stream.
13 std::stringstream data;
14 data << "abcd" // First four bytes don't matter.
15 << "PTMF" // Magic string.
16 << "efgh"; // Trailing bytes don't matter.
17
18 // The data should be identified as Flatbuffer.
19 EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
20}
21
22TEST(FileFormatTest, IdentifiesZipStream) {
23 // Create data whose initial bytes look like a ZIP stream.
24 std::stringstream data;
25 data << "PK\x03\x04" // Magic string.
26 << "abcd" // Trailing bytes don't matter.
27 << "efgh";
28
29 // The data should be identified as ZIP.
30 EXPECT_EQ(getFileFormat(data), FileFormat::ZipFileFormat);
31}
32
33TEST(FileFormatTest, FlatbufferTakesPrecedence) {
34 // Since the Flatbuffer and ZIP magic bytes are at different offsets,
35 // the same data could be identified as both. Demonstrate that Flatbuffer
36 // takes precedence. (See details in file_format.h)
37 std::stringstream data;
38 data << "PK\x03\x04" // ZIP magic string.
39 << "PTMF" // Flatbuffer magic string.
40 << "abcd"; // Trailing bytes don't matter.
41
42 // The data should be identified as Flatbuffer.
43 EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
44}
45
46TEST(FileFormatTest, HandlesUnknownStream) {
47 // Create data that doesn't look like any known format.
48 std::stringstream data;
49 data << "abcd"
50 << "efgh"
51 << "ijkl";
52
53 // The data should be classified as unknown.
54 EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
55}
56
57TEST(FileFormatTest, ShortStreamIsUnknown) {
58 // Create data with fewer than kFileFormatHeaderSize (8) bytes.
59 std::stringstream data;
60 data << "ABCD";
61
62 // The data should be classified as unknown.
63 EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
64}
65
66TEST(FileFormatTest, EmptyStreamIsUnknown) {
67 // Create an empty stream.
68 std::stringstream data;
69
70 // The data should be classified as unknown.
71 EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
72}
73
74TEST(FileFormatTest, BadStreamIsUnknown) {
75 // Create a stream with valid Flatbuffer data.
76 std::stringstream data;
77 data << "abcd"
78 << "PTMF" // Flatbuffer magic string.
79 << "efgh";
80
81 // Demonstrate that the data would normally be identified as Flatbuffer.
82 EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
83
84 // Mark the stream as bad, and demonstrate that it is in an error state.
85 data.setstate(std::stringstream::badbit);
86 // Demonstrate that the stream is in an error state.
87 EXPECT_FALSE(data.good());
88
89 // The data should now be classified as unknown.
90 EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
91}
92
93TEST(FileFormatTest, StreamOffsetIsObservedAndRestored) {
94 // Create data with a Flatbuffer header at a non-zero offset into the stream.
95 std::stringstream data;
96 // Add initial padding.
97 data << "PADDING";
98 size_t offset = data.str().size();
99 // Add a valid Flatbuffer header.
100 data << "abcd"
101 << "PTMF" // Flatbuffer magic string.
102 << "efgh";
103 // Seek just after the padding.
104 data.seekg(static_cast<std::stringstream::off_type>(offset), data.beg);
105 // Demonstrate that the stream points to the beginning of the Flatbuffer data,
106 // not to the padding.
107 EXPECT_EQ(data.peek(), 'a');
108
109 // The data should be identified as Flatbuffer.
110 EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
111
112 // The stream position should be where it was before identification.
113 EXPECT_EQ(offset, data.tellg());
114}
115
116TEST(FileFormatTest, HandlesMissingFile) {
117 // A missing file should be classified as unknown.
118 EXPECT_EQ(
119 getFileFormat("NON_EXISTENT_FILE_4965c363-44a7-443c-983a-8895eead0277"),
120 FileFormat::UnknownFileFormat);
121}
122
123} // namespace jit
124} // namespace torch
125