1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/core/TensorOptions.h> |
4 | #include <torch/csrc/autograd/generated/variable_factories.h> |
5 | #include <torch/csrc/jit/api/module.h> |
6 | #include <torch/csrc/jit/mobile/import.h> |
7 | #include <torch/csrc/jit/mobile/import_data.h> |
8 | #include <torch/csrc/jit/mobile/module.h> |
9 | #include <torch/csrc/jit/mobile/train/export_data.h> |
10 | #include <torch/csrc/jit/mobile/train/optim/sgd.h> |
11 | #include <torch/csrc/jit/mobile/train/random.h> |
12 | #include <torch/csrc/jit/mobile/train/sequential.h> |
13 | #include <torch/csrc/jit/serialization/import.h> |
14 | #include <torch/data/dataloader.h> |
15 | #include <torch/torch.h> |
16 | |
17 | // Tests go in torch::jit |
18 | namespace torch { |
19 | namespace jit { |
20 | |
21 | TEST(LiteTrainerTest, Params) { |
22 | Module m("m" ); |
23 | m.register_parameter("foo" , torch::ones({1}, at::requires_grad()), false); |
24 | m.define(R"( |
25 | def forward(self, x): |
26 | b = 1.0 |
27 | return self.foo * x + b |
28 | )" ); |
29 | double learning_rate = 0.1, momentum = 0.1; |
30 | int n_epoc = 10; |
31 | // init: y = x + 1; |
32 | // target: y = 2 x + 1 |
33 | std::vector<std::pair<Tensor, Tensor>> trainData{ |
34 | {1 * torch::ones({1}), 3 * torch::ones({1})}, |
35 | }; |
36 | // Reference: Full jit |
37 | std::stringstream ms; |
38 | m.save(ms); |
39 | auto mm = load(ms); |
40 | // mm.train(); |
41 | std::vector<::at::Tensor> parameters; |
42 | for (auto parameter : mm.parameters()) { |
43 | parameters.emplace_back(parameter); |
44 | } |
45 | ::torch::optim::SGD optimizer( |
46 | parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
47 | for (int epoc = 0; epoc < n_epoc; ++epoc) { |
48 | for (auto& data : trainData) { |
49 | auto source = data.first, targets = data.second; |
50 | optimizer.zero_grad(); |
51 | std::vector<IValue> train_inputs{source}; |
52 | auto output = mm.forward(train_inputs).toTensor(); |
53 | auto loss = ::torch::l1_loss(output, targets); |
54 | loss.backward(); |
55 | optimizer.step(); |
56 | } |
57 | } |
58 | std::stringstream ss; |
59 | m._save_for_mobile(ss); |
60 | mobile::Module bc = _load_for_mobile(ss); |
61 | std::vector<::at::Tensor> bc_parameters = bc.parameters(); |
62 | ::torch::optim::SGD bc_optimizer( |
63 | bc_parameters, |
64 | ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
65 | for (int epoc = 0; epoc < n_epoc; ++epoc) { |
66 | for (auto& data : trainData) { |
67 | auto source = data.first, targets = data.second; |
68 | bc_optimizer.zero_grad(); |
69 | std::vector<IValue> train_inputs{source}; |
70 | auto output = bc.forward(train_inputs).toTensor(); |
71 | auto loss = ::torch::l1_loss(output, targets); |
72 | loss.backward(); |
73 | bc_optimizer.step(); |
74 | } |
75 | } |
76 | AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>()); |
77 | } |
78 | |
// TODO Re-enable these tests after parameters are correctly loaded on mobile
80 | /* |
81 | TEST(MobileTest, NamedParameters) { |
82 | Module m("m"); |
83 | m.register_parameter("foo", torch::ones({}), false); |
84 | m.define(R"( |
85 | def add_it(self, x): |
86 | b = 4 |
87 | return self.foo + x + b |
88 | )"); |
89 | Module child("m2"); |
90 | child.register_parameter("foo", 4 * torch::ones({}), false); |
91 | child.register_parameter("bar", 4 * torch::ones({}), false); |
92 | m.register_module("child1", child); |
93 | m.register_module("child2", child.clone()); |
94 | std::stringstream ss; |
95 | m._save_for_mobile(ss); |
96 | mobile::Module bc = _load_for_mobile(ss); |
97 | |
98 | auto full_params = m.named_parameters(); |
99 | auto mobile_params = bc.named_parameters(); |
100 | AT_ASSERT(full_params.size() == mobile_params.size()); |
101 | for (const auto& e : full_params) { |
102 | AT_ASSERT(e.value.item().toInt() == |
103 | mobile_params[e.name].item().toInt()); |
104 | } |
105 | } |
106 | |
107 | TEST(MobileTest, SaveLoadParameters) { |
108 | Module m("m"); |
109 | m.register_parameter("foo", torch::ones({}), false); |
110 | m.define(R"( |
111 | def add_it(self, x): |
112 | b = 4 |
113 | return self.foo + x + b |
114 | )"); |
115 | Module child("m2"); |
116 | child.register_parameter("foo", 4 * torch::ones({}), false); |
117 | child.register_parameter("bar", 3 * torch::ones({}), false); |
118 | m.register_module("child1", child); |
119 | m.register_module("child2", child.clone()); |
120 | auto full_params = m.named_parameters(); |
121 | std::stringstream ss; |
122 | std::stringstream ss_data; |
123 | m._save_for_mobile(ss); |
124 | |
125 | // load mobile module, save mobile named parameters |
126 | mobile::Module bc = _load_for_mobile(ss); |
127 | _save_parameters(bc.named_parameters(), ss_data); |
128 | |
129 | // load back the named parameters, compare to full-jit Module's |
130 | auto mobile_params = _load_parameters(ss_data); |
131 | AT_ASSERT(full_params.size() == mobile_params.size()); |
132 | for (const auto& e : full_params) { |
133 | AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>()); |
134 | } |
135 | } |
136 | */ |
137 | |
138 | TEST(MobileTest, SaveLoadParametersEmpty) { |
139 | Module m("m" ); |
140 | m.define(R"( |
141 | def add_it(self, x): |
142 | b = 4 |
143 | return x + b |
144 | )" ); |
145 | Module child("m2" ); |
146 | m.register_module("child1" , child); |
147 | m.register_module("child2" , child.clone()); |
148 | std::stringstream ss; |
149 | std::stringstream ss_data; |
150 | m._save_for_mobile(ss); |
151 | |
152 | // load mobile module, save mobile named parameters |
153 | mobile::Module bc = _load_for_mobile(ss); |
154 | _save_parameters(bc.named_parameters(), ss_data); |
155 | |
156 | // load back the named parameters, test is empty |
157 | auto mobile_params = _load_parameters(ss_data); |
158 | AT_ASSERT(mobile_params.size() == 0); |
159 | } |
160 | |
161 | TEST(MobileTest, SaveParametersDefaultsToZip) { |
162 | // Save some empty parameters. |
163 | std::map<std::string, at::Tensor> empty_parameters; |
164 | std::stringstream ss_data; |
165 | _save_parameters(empty_parameters, ss_data); |
166 | |
167 | // Verify that parameters were serialized to a ZIP container. |
168 | EXPECT_GE(ss_data.str().size(), 4); |
169 | EXPECT_EQ(ss_data.str()[0], 'P'); |
170 | EXPECT_EQ(ss_data.str()[1], 'K'); |
171 | EXPECT_EQ(ss_data.str()[2], '\x03'); |
172 | EXPECT_EQ(ss_data.str()[3], '\x04'); |
173 | } |
174 | |
175 | TEST(MobileTest, SaveParametersCanUseFlatbuffer) { |
176 | // Save some empty parameters using flatbuffer. |
177 | std::map<std::string, at::Tensor> empty_parameters; |
178 | std::stringstream ss_data; |
179 | _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true); |
180 | |
181 | // Verify that parameters were serialized to a flatbuffer. The flatbuffer |
182 | // magic bytes should be at offsets 4..7. The first four bytes contain an |
183 | // offset to the actual flatbuffer data. |
184 | EXPECT_GE(ss_data.str().size(), 8); |
185 | EXPECT_EQ(ss_data.str()[4], 'P'); |
186 | EXPECT_EQ(ss_data.str()[5], 'T'); |
187 | EXPECT_EQ(ss_data.str()[6], 'M'); |
188 | EXPECT_EQ(ss_data.str()[7], 'F'); |
189 | } |
190 | |
191 | TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) { |
192 | // Create some simple parameters to save. |
193 | std::map<std::string, at::Tensor> input_params; |
194 | input_params["four_by_ones" ] = 4 * torch::ones({}); |
195 | input_params["three_by_ones" ] = 3 * torch::ones({}); |
196 | |
197 | // Serialize them using flatbuffers. |
198 | std::stringstream data; |
199 | _save_parameters(input_params, data, /*use_flatbuffer=*/true); |
200 | |
201 | // The flatbuffer magic bytes should be at offsets 4..7. |
202 | EXPECT_EQ(data.str()[4], 'P'); |
203 | EXPECT_EQ(data.str()[5], 'T'); |
204 | EXPECT_EQ(data.str()[6], 'M'); |
205 | EXPECT_EQ(data.str()[7], 'F'); |
206 | |
207 | // Read them back and check that they survived the trip. |
208 | auto output_params = _load_parameters(data); |
209 | EXPECT_EQ(output_params.size(), 2); |
210 | { |
211 | auto four_by_ones = 4 * torch::ones({}); |
212 | EXPECT_EQ( |
213 | output_params["four_by_ones" ].item<int>(), four_by_ones.item<int>()); |
214 | } |
215 | { |
216 | auto three_by_ones = 3 * torch::ones({}); |
217 | EXPECT_EQ( |
218 | output_params["three_by_ones" ].item<int>(), three_by_ones.item<int>()); |
219 | } |
220 | } |
221 | |
222 | TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) { |
223 | // Manually create some data that doesn't look like a ZIP or Flatbuffer file. |
224 | // Make sure it's longer than 8 bytes, since getFileFormat() needs that much |
225 | // data to detect the type. |
226 | std::stringstream bad_data; |
227 | bad_data << "abcd" |
228 | << "efgh" |
229 | << "ijkl" ; |
230 | |
231 | // Loading parameters from it should throw an exception. |
232 | EXPECT_ANY_THROW(_load_parameters(bad_data)); |
233 | } |
234 | |
235 | TEST(MobileTest, LoadParametersEmptyDataShouldThrow) { |
236 | // Loading parameters from an empty data stream should throw an exception. |
237 | std::stringstream empty; |
238 | EXPECT_ANY_THROW(_load_parameters(empty)); |
239 | } |
240 | |
241 | TEST(LiteTrainerTest, SGD) { |
242 | Module m("m" ); |
243 | m.register_parameter("foo" , torch::ones({1}, at::requires_grad()), false); |
244 | m.define(R"( |
245 | def forward(self, x): |
246 | b = 1.0 |
247 | return self.foo * x + b |
248 | )" ); |
249 | double learning_rate = 0.1, momentum = 0.1; |
250 | int n_epoc = 10; |
251 | // init: y = x + 1; |
252 | // target: y = 2 x + 1 |
253 | std::vector<std::pair<Tensor, Tensor>> trainData{ |
254 | {1 * torch::ones({1}), 3 * torch::ones({1})}, |
255 | }; |
256 | // Reference: Full jit and torch::optim::SGD |
257 | std::stringstream ms; |
258 | m.save(ms); |
259 | auto mm = load(ms); |
260 | std::vector<::at::Tensor> parameters; |
261 | for (auto parameter : mm.parameters()) { |
262 | parameters.emplace_back(parameter); |
263 | } |
264 | ::torch::optim::SGD optimizer( |
265 | parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
266 | for (int epoc = 0; epoc < n_epoc; ++epoc) { |
267 | for (auto& data : trainData) { |
268 | auto source = data.first, targets = data.second; |
269 | optimizer.zero_grad(); |
270 | std::vector<IValue> train_inputs{source}; |
271 | auto output = mm.forward(train_inputs).toTensor(); |
272 | auto loss = ::torch::l1_loss(output, targets); |
273 | loss.backward(); |
274 | optimizer.step(); |
275 | } |
276 | } |
277 | // Test: lite interpreter and torch::jit::mobile::SGD |
278 | std::stringstream ss; |
279 | m._save_for_mobile(ss); |
280 | mobile::Module bc = _load_for_mobile(ss); |
281 | std::vector<::at::Tensor> bc_parameters = bc.parameters(); |
282 | ::torch::jit::mobile::SGD bc_optimizer( |
283 | bc_parameters, |
284 | ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum)); |
285 | for (int epoc = 0; epoc < n_epoc; ++epoc) { |
286 | for (auto& data : trainData) { |
287 | auto source = data.first, targets = data.second; |
288 | bc_optimizer.zero_grad(); |
289 | std::vector<IValue> train_inputs{source}; |
290 | auto output = bc.forward(train_inputs).toTensor(); |
291 | auto loss = ::torch::l1_loss(output, targets); |
292 | loss.backward(); |
293 | bc_optimizer.step(); |
294 | } |
295 | } |
296 | AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>()); |
297 | } |
298 | |
namespace {
// Minimal in-memory dataset serving the values 1..size_ in index order;
// used below to exercise the mobile samplers through a DataLoader.
struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
  explicit DummyDataset(size_t size = 100) : size_(size) {}

  // Returns the 1-based value for `index` (index 0 -> 1, index 1 -> 2, ...).
  int get(size_t index) override {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    return 1 + index;
  }
  // Number of examples in the dataset.
  torch::optional<size_t> size() const override {
    return size_;
  }

  size_t size_;
};
} // namespace
314 | |
315 | TEST(LiteTrainerTest, SequentialSampler) { |
316 | // test that sampler can be used with dataloader |
317 | const int kBatchSize = 10; |
318 | auto data_loader = torch::data::make_data_loader<mobile::SequentialSampler>( |
319 | DummyDataset(25), kBatchSize); |
320 | int i = 1; |
321 | for (const auto& batch : *data_loader) { |
322 | for (const auto& example : batch) { |
323 | AT_ASSERT(i == example); |
324 | i++; |
325 | } |
326 | } |
327 | } |
328 | |
329 | TEST(LiteTrainerTest, RandomSamplerReturnsIndicesInCorrectRange) { |
330 | mobile::RandomSampler sampler(10); |
331 | |
332 | std::vector<size_t> indices = sampler.next(3).value(); |
333 | for (auto i : indices) { |
334 | AT_ASSERT(i < 10); |
335 | } |
336 | |
337 | indices = sampler.next(5).value(); |
338 | for (auto i : indices) { |
339 | AT_ASSERT(i < 10); |
340 | } |
341 | |
342 | indices = sampler.next(2).value(); |
343 | for (auto i : indices) { |
344 | AT_ASSERT(i < 10); |
345 | } |
346 | |
347 | AT_ASSERT(sampler.next(10).has_value() == false); |
348 | } |
349 | |
350 | TEST(LiteTrainerTest, RandomSamplerReturnsLessValuesForLastBatch) { |
351 | mobile::RandomSampler sampler(5); |
352 | AT_ASSERT(sampler.next(3).value().size() == 3); |
353 | AT_ASSERT(sampler.next(100).value().size() == 2); |
354 | AT_ASSERT(sampler.next(2).has_value() == false); |
355 | } |
356 | |
357 | TEST(LiteTrainerTest, RandomSamplerResetsWell) { |
358 | mobile::RandomSampler sampler(5); |
359 | AT_ASSERT(sampler.next(5).value().size() == 5); |
360 | AT_ASSERT(sampler.next(2).has_value() == false); |
361 | sampler.reset(); |
362 | AT_ASSERT(sampler.next(5).value().size() == 5); |
363 | AT_ASSERT(sampler.next(2).has_value() == false); |
364 | } |
365 | |
366 | TEST(LiteTrainerTest, RandomSamplerResetsWithNewSizeWell) { |
367 | mobile::RandomSampler sampler(5); |
368 | AT_ASSERT(sampler.next(5).value().size() == 5); |
369 | AT_ASSERT(sampler.next(2).has_value() == false); |
370 | sampler.reset(7); |
371 | AT_ASSERT(sampler.next(7).value().size() == 7); |
372 | AT_ASSERT(sampler.next(2).has_value() == false); |
373 | sampler.reset(3); |
374 | AT_ASSERT(sampler.next(3).value().size() == 3); |
375 | AT_ASSERT(sampler.next(2).has_value() == false); |
376 | } |
377 | |
378 | } // namespace jit |
379 | } // namespace torch |
380 | |