1#include <gtest/gtest.h>
2
3#include <c10/core/TensorOptions.h>
4#include <torch/csrc/autograd/generated/variable_factories.h>
5#include <torch/csrc/jit/api/module.h>
6#include <torch/csrc/jit/mobile/import.h>
7#include <torch/csrc/jit/mobile/import_data.h>
8#include <torch/csrc/jit/mobile/module.h>
9#include <torch/csrc/jit/mobile/train/export_data.h>
10#include <torch/csrc/jit/mobile/train/optim/sgd.h>
11#include <torch/csrc/jit/mobile/train/random.h>
12#include <torch/csrc/jit/mobile/train/sequential.h>
13#include <torch/csrc/jit/serialization/import.h>
14#include <torch/data/dataloader.h>
15#include <torch/torch.h>
16
17// Tests go in torch::jit
18namespace torch {
19namespace jit {
20
21TEST(LiteTrainerTest, Params) {
22 Module m("m");
23 m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
24 m.define(R"(
25 def forward(self, x):
26 b = 1.0
27 return self.foo * x + b
28 )");
29 double learning_rate = 0.1, momentum = 0.1;
30 int n_epoc = 10;
31 // init: y = x + 1;
32 // target: y = 2 x + 1
33 std::vector<std::pair<Tensor, Tensor>> trainData{
34 {1 * torch::ones({1}), 3 * torch::ones({1})},
35 };
36 // Reference: Full jit
37 std::stringstream ms;
38 m.save(ms);
39 auto mm = load(ms);
40 // mm.train();
41 std::vector<::at::Tensor> parameters;
42 for (auto parameter : mm.parameters()) {
43 parameters.emplace_back(parameter);
44 }
45 ::torch::optim::SGD optimizer(
46 parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
47 for (int epoc = 0; epoc < n_epoc; ++epoc) {
48 for (auto& data : trainData) {
49 auto source = data.first, targets = data.second;
50 optimizer.zero_grad();
51 std::vector<IValue> train_inputs{source};
52 auto output = mm.forward(train_inputs).toTensor();
53 auto loss = ::torch::l1_loss(output, targets);
54 loss.backward();
55 optimizer.step();
56 }
57 }
58 std::stringstream ss;
59 m._save_for_mobile(ss);
60 mobile::Module bc = _load_for_mobile(ss);
61 std::vector<::at::Tensor> bc_parameters = bc.parameters();
62 ::torch::optim::SGD bc_optimizer(
63 bc_parameters,
64 ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
65 for (int epoc = 0; epoc < n_epoc; ++epoc) {
66 for (auto& data : trainData) {
67 auto source = data.first, targets = data.second;
68 bc_optimizer.zero_grad();
69 std::vector<IValue> train_inputs{source};
70 auto output = bc.forward(train_inputs).toTensor();
71 auto loss = ::torch::l1_loss(output, targets);
72 loss.backward();
73 bc_optimizer.step();
74 }
75 }
76 AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
77}
78
// TODO Re-enable these tests after parameters are correctly loaded on mobile
80/*
81TEST(MobileTest, NamedParameters) {
82 Module m("m");
83 m.register_parameter("foo", torch::ones({}), false);
84 m.define(R"(
85 def add_it(self, x):
86 b = 4
87 return self.foo + x + b
88 )");
89 Module child("m2");
90 child.register_parameter("foo", 4 * torch::ones({}), false);
91 child.register_parameter("bar", 4 * torch::ones({}), false);
92 m.register_module("child1", child);
93 m.register_module("child2", child.clone());
94 std::stringstream ss;
95 m._save_for_mobile(ss);
96 mobile::Module bc = _load_for_mobile(ss);
97
98 auto full_params = m.named_parameters();
99 auto mobile_params = bc.named_parameters();
100 AT_ASSERT(full_params.size() == mobile_params.size());
101 for (const auto& e : full_params) {
102 AT_ASSERT(e.value.item().toInt() ==
103 mobile_params[e.name].item().toInt());
104 }
105}
106
107TEST(MobileTest, SaveLoadParameters) {
108 Module m("m");
109 m.register_parameter("foo", torch::ones({}), false);
110 m.define(R"(
111 def add_it(self, x):
112 b = 4
113 return self.foo + x + b
114 )");
115 Module child("m2");
116 child.register_parameter("foo", 4 * torch::ones({}), false);
117 child.register_parameter("bar", 3 * torch::ones({}), false);
118 m.register_module("child1", child);
119 m.register_module("child2", child.clone());
120 auto full_params = m.named_parameters();
121 std::stringstream ss;
122 std::stringstream ss_data;
123 m._save_for_mobile(ss);
124
125 // load mobile module, save mobile named parameters
126 mobile::Module bc = _load_for_mobile(ss);
127 _save_parameters(bc.named_parameters(), ss_data);
128
129 // load back the named parameters, compare to full-jit Module's
130 auto mobile_params = _load_parameters(ss_data);
131 AT_ASSERT(full_params.size() == mobile_params.size());
132 for (const auto& e : full_params) {
133 AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());
134 }
135}
136*/
137
138TEST(MobileTest, SaveLoadParametersEmpty) {
139 Module m("m");
140 m.define(R"(
141 def add_it(self, x):
142 b = 4
143 return x + b
144 )");
145 Module child("m2");
146 m.register_module("child1", child);
147 m.register_module("child2", child.clone());
148 std::stringstream ss;
149 std::stringstream ss_data;
150 m._save_for_mobile(ss);
151
152 // load mobile module, save mobile named parameters
153 mobile::Module bc = _load_for_mobile(ss);
154 _save_parameters(bc.named_parameters(), ss_data);
155
156 // load back the named parameters, test is empty
157 auto mobile_params = _load_parameters(ss_data);
158 AT_ASSERT(mobile_params.size() == 0);
159}
160
161TEST(MobileTest, SaveParametersDefaultsToZip) {
162 // Save some empty parameters.
163 std::map<std::string, at::Tensor> empty_parameters;
164 std::stringstream ss_data;
165 _save_parameters(empty_parameters, ss_data);
166
167 // Verify that parameters were serialized to a ZIP container.
168 EXPECT_GE(ss_data.str().size(), 4);
169 EXPECT_EQ(ss_data.str()[0], 'P');
170 EXPECT_EQ(ss_data.str()[1], 'K');
171 EXPECT_EQ(ss_data.str()[2], '\x03');
172 EXPECT_EQ(ss_data.str()[3], '\x04');
173}
174
175TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
176 // Save some empty parameters using flatbuffer.
177 std::map<std::string, at::Tensor> empty_parameters;
178 std::stringstream ss_data;
179 _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);
180
181 // Verify that parameters were serialized to a flatbuffer. The flatbuffer
182 // magic bytes should be at offsets 4..7. The first four bytes contain an
183 // offset to the actual flatbuffer data.
184 EXPECT_GE(ss_data.str().size(), 8);
185 EXPECT_EQ(ss_data.str()[4], 'P');
186 EXPECT_EQ(ss_data.str()[5], 'T');
187 EXPECT_EQ(ss_data.str()[6], 'M');
188 EXPECT_EQ(ss_data.str()[7], 'F');
189}
190
191TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
192 // Create some simple parameters to save.
193 std::map<std::string, at::Tensor> input_params;
194 input_params["four_by_ones"] = 4 * torch::ones({});
195 input_params["three_by_ones"] = 3 * torch::ones({});
196
197 // Serialize them using flatbuffers.
198 std::stringstream data;
199 _save_parameters(input_params, data, /*use_flatbuffer=*/true);
200
201 // The flatbuffer magic bytes should be at offsets 4..7.
202 EXPECT_EQ(data.str()[4], 'P');
203 EXPECT_EQ(data.str()[5], 'T');
204 EXPECT_EQ(data.str()[6], 'M');
205 EXPECT_EQ(data.str()[7], 'F');
206
207 // Read them back and check that they survived the trip.
208 auto output_params = _load_parameters(data);
209 EXPECT_EQ(output_params.size(), 2);
210 {
211 auto four_by_ones = 4 * torch::ones({});
212 EXPECT_EQ(
213 output_params["four_by_ones"].item<int>(), four_by_ones.item<int>());
214 }
215 {
216 auto three_by_ones = 3 * torch::ones({});
217 EXPECT_EQ(
218 output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
219 }
220}
221
222TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
223 // Manually create some data that doesn't look like a ZIP or Flatbuffer file.
224 // Make sure it's longer than 8 bytes, since getFileFormat() needs that much
225 // data to detect the type.
226 std::stringstream bad_data;
227 bad_data << "abcd"
228 << "efgh"
229 << "ijkl";
230
231 // Loading parameters from it should throw an exception.
232 EXPECT_ANY_THROW(_load_parameters(bad_data));
233}
234
235TEST(MobileTest, LoadParametersEmptyDataShouldThrow) {
236 // Loading parameters from an empty data stream should throw an exception.
237 std::stringstream empty;
238 EXPECT_ANY_THROW(_load_parameters(empty));
239}
240
241TEST(LiteTrainerTest, SGD) {
242 Module m("m");
243 m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
244 m.define(R"(
245 def forward(self, x):
246 b = 1.0
247 return self.foo * x + b
248 )");
249 double learning_rate = 0.1, momentum = 0.1;
250 int n_epoc = 10;
251 // init: y = x + 1;
252 // target: y = 2 x + 1
253 std::vector<std::pair<Tensor, Tensor>> trainData{
254 {1 * torch::ones({1}), 3 * torch::ones({1})},
255 };
256 // Reference: Full jit and torch::optim::SGD
257 std::stringstream ms;
258 m.save(ms);
259 auto mm = load(ms);
260 std::vector<::at::Tensor> parameters;
261 for (auto parameter : mm.parameters()) {
262 parameters.emplace_back(parameter);
263 }
264 ::torch::optim::SGD optimizer(
265 parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
266 for (int epoc = 0; epoc < n_epoc; ++epoc) {
267 for (auto& data : trainData) {
268 auto source = data.first, targets = data.second;
269 optimizer.zero_grad();
270 std::vector<IValue> train_inputs{source};
271 auto output = mm.forward(train_inputs).toTensor();
272 auto loss = ::torch::l1_loss(output, targets);
273 loss.backward();
274 optimizer.step();
275 }
276 }
277 // Test: lite interpreter and torch::jit::mobile::SGD
278 std::stringstream ss;
279 m._save_for_mobile(ss);
280 mobile::Module bc = _load_for_mobile(ss);
281 std::vector<::at::Tensor> bc_parameters = bc.parameters();
282 ::torch::jit::mobile::SGD bc_optimizer(
283 bc_parameters,
284 ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum));
285 for (int epoc = 0; epoc < n_epoc; ++epoc) {
286 for (auto& data : trainData) {
287 auto source = data.first, targets = data.second;
288 bc_optimizer.zero_grad();
289 std::vector<IValue> train_inputs{source};
290 auto output = bc.forward(train_inputs).toTensor();
291 auto loss = ::torch::l1_loss(output, targets);
292 loss.backward();
293 bc_optimizer.step();
294 }
295 }
296 AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
297}
298
299namespace {
300struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
301 explicit DummyDataset(size_t size = 100) : size_(size) {}
302
303 int get(size_t index) override {
304 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
305 return 1 + index;
306 }
307 torch::optional<size_t> size() const override {
308 return size_;
309 }
310
311 size_t size_;
312};
313} // namespace
314
315TEST(LiteTrainerTest, SequentialSampler) {
316 // test that sampler can be used with dataloader
317 const int kBatchSize = 10;
318 auto data_loader = torch::data::make_data_loader<mobile::SequentialSampler>(
319 DummyDataset(25), kBatchSize);
320 int i = 1;
321 for (const auto& batch : *data_loader) {
322 for (const auto& example : batch) {
323 AT_ASSERT(i == example);
324 i++;
325 }
326 }
327}
328
329TEST(LiteTrainerTest, RandomSamplerReturnsIndicesInCorrectRange) {
330 mobile::RandomSampler sampler(10);
331
332 std::vector<size_t> indices = sampler.next(3).value();
333 for (auto i : indices) {
334 AT_ASSERT(i < 10);
335 }
336
337 indices = sampler.next(5).value();
338 for (auto i : indices) {
339 AT_ASSERT(i < 10);
340 }
341
342 indices = sampler.next(2).value();
343 for (auto i : indices) {
344 AT_ASSERT(i < 10);
345 }
346
347 AT_ASSERT(sampler.next(10).has_value() == false);
348}
349
350TEST(LiteTrainerTest, RandomSamplerReturnsLessValuesForLastBatch) {
351 mobile::RandomSampler sampler(5);
352 AT_ASSERT(sampler.next(3).value().size() == 3);
353 AT_ASSERT(sampler.next(100).value().size() == 2);
354 AT_ASSERT(sampler.next(2).has_value() == false);
355}
356
357TEST(LiteTrainerTest, RandomSamplerResetsWell) {
358 mobile::RandomSampler sampler(5);
359 AT_ASSERT(sampler.next(5).value().size() == 5);
360 AT_ASSERT(sampler.next(2).has_value() == false);
361 sampler.reset();
362 AT_ASSERT(sampler.next(5).value().size() == 5);
363 AT_ASSERT(sampler.next(2).has_value() == false);
364}
365
366TEST(LiteTrainerTest, RandomSamplerResetsWithNewSizeWell) {
367 mobile::RandomSampler sampler(5);
368 AT_ASSERT(sampler.next(5).value().size() == 5);
369 AT_ASSERT(sampler.next(2).has_value() == false);
370 sampler.reset(7);
371 AT_ASSERT(sampler.next(7).value().size() == 7);
372 AT_ASSERT(sampler.next(2).has_value() == false);
373 sampler.reset(3);
374 AT_ASSERT(sampler.next(3).value().size() == 3);
375 AT_ASSERT(sampler.next(2).has_value() == false);
376}
377
378} // namespace jit
379} // namespace torch
380