1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | #include <c10/util/ArrayRef.h> |
8 | #include <c10/util/irange.h> |
9 | #include <c10/util/tempfile.h> |
10 | |
11 | #include <algorithm> |
12 | #include <chrono> |
13 | #include <future> |
14 | #include <iostream> |
15 | #include <iterator> |
16 | #include <limits> |
17 | #include <mutex> |
18 | #include <numeric> |
19 | #include <stdexcept> |
20 | #include <string> |
21 | #include <thread> |
22 | #include <unordered_set> |
23 | #include <vector> |
24 | |
25 | using namespace torch::data; // NOLINT |
26 | |
// Convenience time unit for building short timeouts in the queue/shuttle
// tests below (e.g. `10 * kMillisecond` == 10ms).
const std::chrono::milliseconds kMillisecond(1);
28 | |
29 | struct DummyDataset : datasets::Dataset<DummyDataset, int> { |
30 | explicit DummyDataset(size_t size = 100) : size_(size) {} |
31 | |
32 | int get(size_t index) override { |
33 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
34 | return 1 + index; |
35 | } |
36 | torch::optional<size_t> size() const override { |
37 | return size_; |
38 | } |
39 | |
40 | size_t size_; |
41 | }; |
42 | |
43 | TEST(DataTest, DatasetCallsGetCorrectly) { |
44 | DummyDataset d; |
45 | std::vector<int> batch = d.get_batch({0, 1, 2, 3, 4}); |
46 | std::vector<int> expected = {1, 2, 3, 4, 5}; |
47 | ASSERT_EQ(batch, expected); |
48 | } |
49 | |
50 | TEST(DataTest, TransformCallsGetApplyCorrectly) { |
51 | struct T : transforms::Transform<int, std::string> { |
52 | std::string apply(int input) override { |
53 | return std::to_string(input); |
54 | } |
55 | }; |
56 | |
57 | auto d = DummyDataset{}.map(T{}); |
58 | std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4}); |
59 | std::vector<std::string> expected = {"1" , "2" , "3" , "4" , "5" }; |
60 | ASSERT_EQ(batch, expected); |
61 | } |
62 | |
63 | // dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk |
64 | // contains 10, 5, 20 examples respectively. |
65 | |
66 | struct DummyChunkDataReader : public datasets::ChunkDataReader<int> { |
67 | public: |
68 | using BatchType = datasets::ChunkDataReader<int>::ChunkType; |
69 | using DataType = datasets::ChunkDataReader<int>::ExampleType; |
70 | |
71 | /// Read an entire chunk. |
72 | BatchType read_chunk(size_t chunk_index) override { |
73 | BatchType batch_data; |
74 | int start_index = chunk_index == 0 |
75 | ? 0 |
76 | // NOLINTNEXTLINE(bugprone-fold-init-type) |
77 | : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0); |
78 | |
79 | batch_data.resize(chunk_sizes[chunk_index]); |
80 | |
81 | std::iota(batch_data.begin(), batch_data.end(), start_index); |
82 | |
83 | return batch_data; |
84 | } |
85 | |
86 | size_t chunk_count() override { |
87 | return chunk_count_; |
88 | }; |
89 | |
90 | void reset() override{}; |
91 | |
92 | const static size_t chunk_count_ = 3; |
93 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) |
94 | size_t chunk_sizes[chunk_count_] = {10, 5, 20}; |
95 | }; |
96 | |
// ChunkDataset validates its construction options; each degenerate
// configuration below must be rejected with a descriptive exception.
TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // Builds a shared ChunkDataset from the given options; all validation
  // happens inside the ChunkDatasetOptions/ChunkDataset constructors.
  auto initialization_function = [&](size_t preloader_count,
                                     size_t batch_size,
                                     size_t cache_size,
                                     size_t cross_chunk_shuffle_count = 1) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(
                preloader_count,
                batch_size,
                cache_size,
                cross_chunk_shuffle_count));
  };

  ASSERT_THROWS_WITH(
      initialization_function(0, 1, 1),
      "Preloader count is 0. At least one preloader needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 0, 1),
      "Batch size is 0. A positive batch size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 1, 0),
      "Cache size is 0. A positive cache size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 5),
      "Cache size is less than batch size. Cache needs to be large enough to "
      "hold at least one batch.");
  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 20, 0),
      "cross_chunk_shuffle_count needs to be greater than 0.");
}
143 | |
144 | struct InfiniteStreamDataset |
145 | : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> { |
146 | std::vector<int> get_batch(size_t batch_size) override { |
147 | std::vector<int> batch(batch_size); |
148 | for (auto& i : batch) { |
149 | i = counter++; |
150 | } |
151 | return batch; |
152 | } |
153 | |
154 | torch::optional<size_t> size() const override { |
155 | return torch::nullopt; |
156 | } |
157 | |
158 | size_t counter = 0; |
159 | }; |
160 | |
// An infinite stream dataset bounded by a StreamSampler epoch must yield
// exactly epoch_size / batch_size batches (39 / 13 = 3 here), with values
// shifted by one by the Lambda transform.
TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;

  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/39),
      kBatchSize);

  size_t batch_index = 0;
  for (auto& batch : *data_loader) {
    ASSERT_LT(batch_index, 3);
    ASSERT_EQ(batch.size(), kBatchSize);
    for (const auto j : c10::irange(kBatchSize)) {
      // Stream values are consecutive, so element j of batch b is b*13+j+1.
      ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
    }
    batch_index += 1;
  }
  ASSERT_EQ(batch_index, 3);
}
183 | |
184 | TEST(DataTest, NoSequencerIsIdentity) { |
185 | using namespace torch::data::detail::sequencers; // NOLINT |
186 | NoSequencer<int> no_sequencer; |
187 | const auto value = no_sequencer.next([] { return 5; }).value(); |
188 | ASSERT_EQ(value, 5); |
189 | } |
190 | |
// A fresh OrderedSequencer starts at sequence number zero and pre-allocates
// one buffer slot per possible in-flight job.
TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}
201 | |
// An OrderedSequencer must yield batches in sequence-number order even when
// the getter produces them out of order, buffering early arrivals.
TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);

  // Order in which sequence numbers "arrive" from the getter.
  std::vector<size_t> v = {0, 2, 4, 3, 1};
  size_t index = 0;
  auto getter = [&v, &index]() { return S{v.at(index++)}; };

  // Let's say the sequence number matches for the batch one, then it should
  // return immediately.
  const auto batch = sequencer.next(getter);
  ASSERT_EQ(batch.value().sequence_number, 0);
  ASSERT_EQ(index, 1);

  // Now it should call the getter until it gets the next value (buffering
  // 2, 4 and 3 along the way).
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(index, 5);

  // The next three should come in order, straight from the buffer.
  for (size_t i = 2; i <= 4; ++i) {
    // New value doesn't matter. In fact, it shouldn't be accessed.
    ASSERT_EQ(i, sequencer.next(getter).value().sequence_number);
    // The index doesn't change.
    ASSERT_EQ(index, 5);
  }
}
232 | |
233 | TEST(DataTest, BatchLambdaAppliesFunctionToBatch) { |
234 | using InputBatch = std::vector<int>; |
235 | using OutputBatch = std::string; |
236 | DummyDataset d; |
237 | auto e = d.map(transforms::BatchLambda<InputBatch, OutputBatch>( |
238 | [](std::vector<int> input) { |
239 | return std::to_string(std::accumulate(input.begin(), input.end(), 0)); |
240 | })); |
241 | ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string("20" )); |
242 | } |
243 | |
244 | TEST(DataTest, LambdaAppliesFunctionToExample) { |
245 | auto d = DummyDataset().map(transforms::Lambda<int, std::string>( |
246 | static_cast<std::string (*)(int)>(std::to_string))); |
247 | std::vector<std::string> expected = {"1" , "2" , "3" , "4" , "5" }; |
248 | ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected); |
249 | } |
250 | |
251 | TEST(DataTest, CollateReducesBatch) { |
252 | auto d = |
253 | DummyDataset().map(transforms::Collate<int>([](std::vector<int> input) { |
254 | return std::accumulate(input.begin(), input.end(), 0); |
255 | })); |
256 | ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20); |
257 | } |
258 | |
259 | TEST(DataTest, CollationReducesBatch) { |
260 | struct Summer : transforms::Collation<int> { |
261 | int apply_batch(std::vector<int> input) override { |
262 | return std::accumulate(input.begin(), input.end(), 0); |
263 | } |
264 | }; |
265 | auto d = DummyDataset().map(Summer{}); |
266 | ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20); |
267 | } |
268 | |
// SequentialSampler hands out 0..size-1 in order, split across next() calls,
// and signals exhaustion with nullopt.
TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  samplers::SequentialSampler sampler(10);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), std::vector<size_t>({8, 9}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
276 | |
// When fewer than the requested number of indices remain, the sampler
// returns a truncated final batch rather than failing.
TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(100).value(), std::vector<size_t>({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
283 | |
// reset() rewinds an exhausted sampler so it replays the same sequence.
TEST(DataTest, SequentialSamplerResetsWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
292 | |
// reset(new_size) both rewinds the sampler and changes the number of
// indices it will emit (growing or shrinking).
TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(
      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_FALSE(sampler.next(2).has_value());
}
305 | |
// Serializing a SequentialSampler must preserve its current position.
TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  {
    // Round-trip a fresh sampler: the restored index is still 0.
    samplers::SequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    // Round-trip a partially consumed sampler: the restored index is 7.
    samplers::SequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}
330 | |
331 | TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) { |
332 | samplers::RandomSampler sampler(10); |
333 | |
334 | std::vector<size_t> indices = sampler.next(3).value(); |
335 | for (auto i : indices) { |
336 | ASSERT_GE(i, 0); |
337 | ASSERT_LT(i, 10); |
338 | } |
339 | |
340 | indices = sampler.next(5).value(); |
341 | for (auto i : indices) { |
342 | ASSERT_GE(i, 0); |
343 | ASSERT_LT(i, 10); |
344 | } |
345 | |
346 | indices = sampler.next(2).value(); |
347 | for (auto i : indices) { |
348 | ASSERT_GE(i, 0); |
349 | ASSERT_LT(i, 10); |
350 | } |
351 | |
352 | ASSERT_FALSE(sampler.next(10).has_value()); |
353 | } |
354 | |
// A RandomSampler truncates the final batch when fewer indices remain.
TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_EQ(sampler.next(100).value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}
361 | |
// reset() makes an exhausted RandomSampler usable for a fresh epoch.
TEST(DataTest, RandomSamplerResetsWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}
370 | |
// reset(new_size) rewinds the sampler and changes how many indices it emits.
TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}
382 | |
// A RandomSampler restored from serialization must continue with exactly the
// same permutation and position as the original.
TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  {
    // Fresh sampler: both copies must produce the identical permutation.
    samplers::RandomSampler a(10);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);

    ASSERT_EQ(a.next(10).value(), b.next(10).value());
  }
  {
    // Partially consumed sampler: the restored copy resumes at index 3 and
    // yields the same remaining 7 indices.
    samplers::RandomSampler a(10);
    a.next(3);
    ASSERT_EQ(a.index(), 3);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 3);

    auto b_sequence = b.next(10).value();
    ASSERT_EQ(b_sequence.size(), 7);
    ASSERT_EQ(a.next(10).value(), b_sequence);
  }
}
412 | |
// A StreamSampler returns min(requested, remaining) until the epoch size
// (100) is exhausted, then nullopt.
TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  ASSERT_EQ(sampler.next(10).value(), 10);
  ASSERT_EQ(sampler.next(2).value(), 2);
  ASSERT_EQ(sampler.next(85).value(), 85);
  // Only 3 of the requested 123 remain.
  ASSERT_EQ(sampler.next(123).value(), 3);
  ASSERT_FALSE(sampler.next(1).has_value());
}
421 | |
// reset() restores a StreamSampler's full epoch budget.
TEST(DataTest, StreamSamplerResetsWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}
430 | |
// reset(new_size) changes the epoch size of a StreamSampler in-place.
TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}
442 | |
443 | TEST(DataTest, TensorDatasetConstructsFromSingleTensor) { |
444 | datasets::TensorDataset dataset(torch::eye(5)); |
445 | ASSERT_TRUE( |
446 | torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2))); |
447 | } |
448 | |
449 | TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) { |
450 | std::vector<torch::Tensor> vector = torch::eye(5).chunk(5); |
451 | datasets::TensorDataset dataset(vector); |
452 | ASSERT_TRUE( |
453 | torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2))); |
454 | } |
455 | |
456 | TEST(DataTest, StackTransformWorksForExample) { |
457 | struct D : public datasets::Dataset<D> { |
458 | Example<> get(size_t index) override { |
459 | return {tensor[index], 1 + tensor[index]}; |
460 | } |
461 | |
462 | torch::optional<size_t> size() const override { |
463 | return tensor.size(0); |
464 | } |
465 | |
466 | torch::Tensor tensor{torch::eye(4)}; |
467 | }; |
468 | |
469 | auto d = D().map(transforms::Stack<Example<>>()); |
470 | |
471 | Example<> batch = d.get_batch({0, 1}); |
472 | ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2))); |
473 | ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2))); |
474 | |
475 | Example<> second = d.get_batch({2, 3}); |
476 | ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4))); |
477 | ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 2, 4))); |
478 | } |
479 | |
480 | TEST(DataTest, StackTransformWorksForTensorExample) { |
481 | auto d = datasets::TensorDataset(torch::eye(4)) |
482 | .map(transforms::Stack<TensorExample>()); |
483 | |
484 | TensorExample batch = d.get_batch({0, 1}); |
485 | ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2))); |
486 | |
487 | TensorExample second = d.get_batch({2, 3}); |
488 | ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4))); |
489 | } |
490 | |
491 | // Template classes cannot be nested in functions. |
492 | template <typename Target> |
493 | struct T : transforms::TensorTransform<Target> { |
494 | torch::Tensor operator()(torch::Tensor input) override { |
495 | return input * 2; |
496 | } |
497 | }; |
498 | |
// Dataset pairing each index, as a scalar double tensor, with its decimal
// string form as the target; used to exercise non-tensor target types.
struct TensorStringDataset
    : datasets::
          Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
  Example<torch::Tensor, std::string> get(size_t index) override {
    return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};
510 | |
511 | TEST(DataTest, TensorTransformWorksForAnyTargetType) { |
512 | auto d = TensorStringDataset().map(T<std::string>{}); |
513 | std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2}); |
514 | |
515 | ASSERT_EQ(batch.size(), 2); |
516 | ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0))); |
517 | ASSERT_EQ(batch[0].target, "1" ); |
518 | |
519 | ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0))); |
520 | ASSERT_EQ(batch[1].target, "2" ); |
521 | } |
522 | |
523 | TEST(DataTest, TensorLambdaWorksforAnyTargetType) { |
524 | auto d = TensorStringDataset().map(transforms::TensorLambda<std::string>( |
525 | [](torch::Tensor input) { return input * 2; })); |
526 | std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2}); |
527 | |
528 | ASSERT_EQ(batch.size(), 2); |
529 | ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0))); |
530 | ASSERT_EQ(batch[0].target, "1" ); |
531 | |
532 | ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0))); |
533 | ASSERT_EQ(batch[1].target, "2" ); |
534 | } |
535 | |
536 | struct DummyTensorDataset |
537 | : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> { |
538 | Example<torch::Tensor, int> get(size_t index) override { |
539 | const auto channels = static_cast<int64_t>(index); |
540 | torch::Tensor tensor = |
541 | (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4}); |
542 | return {tensor, static_cast<int>(channels)}; |
543 | } |
544 | |
545 | torch::optional<size_t> size() const override { |
546 | return 100; |
547 | } |
548 | }; |
549 | |
// transforms::Normalize computes (x - mean) / stddev and must broadcast its
// moments across 0, 1, or many channels, accepting either a single moment
// value or one value per channel.
TEST(DataTest, NormalizeTransform) {
  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));

  // Works for zero (one implicit) channels
  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
  ASSERT_EQ(output.size(), 1);
  // (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
      << output[0].data;

  // Works for one explicit channel
  output = dataset.get_batch(1);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 1);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;

  // Works for two channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
  output = dataset.get_batch(2);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 2);
  // Channel 0: (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  // Channel 1: (1 - 1.5) / 0.2 = -2.5
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with one moment value
  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  // All channels: (1 - 1.5) / 0.2 = -2.5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  // Channel 0: (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  // Channel 1: (1 - 1.5) / 0.2 = -2.5
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;
  // Channel 2: (1 - (-1.5)) / 0.2 = 12.5
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/2)
                  .allclose(torch::ones({1, 4, 4}) * 12.5))
      << output[0].data;
}
609 | |
610 | struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> { |
611 | UnCopyableDataset() = default; |
612 | |
613 | UnCopyableDataset(const UnCopyableDataset&) = delete; |
614 | UnCopyableDataset& operator=(const UnCopyableDataset&) = delete; |
615 | |
616 | UnCopyableDataset(UnCopyableDataset&&) = default; |
617 | UnCopyableDataset& operator=(UnCopyableDataset&&) = default; |
618 | |
619 | // NOLINTNEXTLINE(modernize-use-override) |
620 | ~UnCopyableDataset() = default; |
621 | |
622 | Example<> get(size_t index) override { |
623 | return { |
624 | torch::tensor({static_cast<int64_t>(index)}), |
625 | torch::tensor({static_cast<int64_t>(index)})}; |
626 | } |
627 | |
628 | torch::optional<size_t> size() const override { |
629 | return 100; |
630 | } |
631 | }; |
632 | |
// Chaining three transforms over a move-only dataset compiles only if no
// copies are made along the way; the final value also checks all three
// transforms were applied.
TEST(DataTest, MapDoesNotCopy) {
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  // Index 1 plus the three additions: 1 + 1 + 2 + 3 = 7.
  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  ASSERT_EQ(data[0].item<float>(), 7);
}
646 | |
// The detail::Queue is FIFO when used from a single thread.
TEST(DataTest, QueuePushAndPopFromSameThread) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  ASSERT_EQ(queue.pop(), 1);
  ASSERT_EQ(queue.pop(), 2);
}
654 | |
// pop() on an empty queue with a timeout must throw, and the message must
// report the timeout that was used.
TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  torch::data::detail::Queue<int> queue;
  ASSERT_THROWS_WITH(
      queue.pop(10 * kMillisecond),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}
662 | |
// The queue must hand values across threads: a waiting pop() unblocks as
// soon as another thread pushes.
TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push batch and the pop in thread.
  {
    Queue<int> queue;
    queue.push(1);
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: attempt to pop batch (and block), then push.
  {
    Queue<int> queue;
    std::thread thread([&queue] {
      // Sleep so that the main thread is (very likely) already blocked in
      // pop() before the push happens.
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    ASSERT_EQ(queue.pop(), 123);
    thread.join();
  }
}
686 | |
// clear() must drop all queued values, report how many it dropped, and leave
// the queue empty (so a subsequent timed pop() times out).
TEST(DataTest, QueueClearEmptiesTheQueue) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  queue.push(3);
  ASSERT_EQ(queue.clear(), 3);
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}
695 | |
// The DataShuttle's job channel behaves as a FIFO.
TEST(DataTest, DataShuttleCanPushAndPopJob) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_job(2);
  ASSERT_EQ(shuttle.pop_job(), 1);
  ASSERT_EQ(shuttle.pop_job(), 2);
}
703 | |
// Results can be popped only while jobs are in flight, so each push_result
// is preceded by a pop_job to keep one job outstanding.
TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() will only attempt to pop if there was a push_job() batch.
  shuttle.push_job(1);
  shuttle.push_job(2);

  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  shuttle.pop_job();
  shuttle.push_result(2);
  ASSERT_EQ(shuttle.pop_result().value(), 2);
}
718 | |
// pop_result() must return nullopt (rather than block) whenever no jobs are
// in flight — both before any job and after all results were consumed.
TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  ASSERT_FALSE(shuttle.pop_result().has_value());
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);
  ASSERT_FALSE(shuttle.pop_result().has_value());
  ASSERT_FALSE(shuttle.pop_result().has_value());
}
729 | |
// drain() discards pending jobs and results, so pop_result() afterwards
// yields nullopt instead of the pushed value.
TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  ASSERT_FALSE(shuttle.pop_result().has_value());
}
737 | |
// With a job in flight but no result produced, a timed pop_result() must
// throw a timeout error.
TEST(DataTest, DataShuttlePopResultTimesOut) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  ASSERT_THROWS_WITH(shuttle.pop_result(10 * kMillisecond), "Timeout");
}
743 | |
// Move-only int dataset (examples are index + 1); the string constructor
// argument only exists to exercise make_shared_dataset forwarding.
struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
  UncopyableDataset(const std::string& /* unused */) {}

  UncopyableDataset(UncopyableDataset&&) = default;
  UncopyableDataset& operator=(UncopyableDataset&&) = default;

  UncopyableDataset(const UncopyableDataset&) = delete;
  UncopyableDataset& operator=(const UncopyableDataset&) = delete;

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return 100;
  }
};
761 | |
762 | TEST(DataTest, SharedBatchDatasetReallyIsShared) { |
763 | // This test will only compile if we really are not making any copies. |
764 | // There is otherwise no logic to test and because it is not deterministic |
765 | // how many and when worker threads access the shareddataset, we don't have |
766 | // any additional assertions here. |
767 | |
768 | auto shared_dataset = |
769 | torch::data::datasets::make_shared_dataset<UncopyableDataset>( |
770 | "uncopyable" ); |
771 | |
772 | auto data_loader = torch::data::make_data_loader( |
773 | shared_dataset, torch::data::DataLoaderOptions().workers(3)); |
774 | |
775 | for (auto batch : *data_loader) { |
776 | /* exhaust */ |
777 | } |
778 | } |
779 | |
// make_shared_dataset must move (not copy) a dataset object passed to it.
TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // This will not compile if a copy is made.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  ASSERT_EQ(shared_dataset.size().value(), 100);
}
787 | |
// Custom batch request carrying a list of positions plus a constant offset
// to add to each fetched value; exercises non-vector<size_t> index types.
struct TestIndex : public torch::data::samplers::CustomBatchRequest {
  explicit TestIndex(size_t offset, std::vector<size_t> index)
      : offset(offset), index(std::move(index)) {}
  size_t size() const override {
    return index.size();
  }
  size_t offset;
  std::vector<size_t> index;
};
797 | |
798 | struct TestIndexDataset |
799 | : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> { |
800 | explicit TestIndexDataset(size_t size) : data(size) { |
801 | std::iota(data.begin(), data.end(), size_t(0)); |
802 | } |
803 | std::vector<int> get_batch(TestIndex index) override { |
804 | std::vector<int> batch; |
805 | for (auto i : index.index) { |
806 | batch.push_back(index.offset + data.at(i)); |
807 | } |
808 | return batch; |
809 | } |
810 | torch::optional<size_t> size() const override { |
811 | return data.size(); |
812 | } |
813 | std::vector<int> data; |
814 | }; |
815 | |
// Sampler emitting TestIndex batches: each batch requests positions
// 0..batch_size-1 with the batch size itself as offset, until `size_`
// positions have been handed out. Serialization hooks are no-ops.
struct TestIndexSampler : public samplers::Sampler<TestIndex> {
  explicit TestIndexSampler(size_t size) : size_(size) {}
  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
  torch::optional<TestIndex> next(size_t batch_size) override {
    if (index_ >= size_) {
      return torch::nullopt;
    }
    std::vector<size_t> indices(batch_size);
    std::iota(indices.begin(), indices.end(), size_t(0));
    index_ += batch_size;
    return TestIndex(batch_size, std::move(indices));
  }
  void save(torch::serialize::OutputArchive& archive) const override {}
  void load(torch::serialize::InputArchive& archive) override {}
  size_t index_ = 0;
  size_t size_;
};
833 | |
834 | TEST(DataTest, CanUseCustomTypeAsIndexType) { |
835 | const int kBatchSize = 10; |
836 | auto data_loader = torch::data::make_data_loader( |
837 | TestIndexDataset(23), TestIndexSampler(23), kBatchSize); |
838 | |
839 | for (auto batch : *data_loader) { |
840 | for (const auto j : c10::irange(kBatchSize)) { |
841 | ASSERT_EQ(batch.at(j), 10 + j); |
842 | } |
843 | } |
844 | } |
845 | |
// With a single replica, a DistributedRandomSampler must yield every index
// in [0, sample_count) exactly once (in some shuffled order).
TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  samplers::DistributedRandomSampler drs(sample_count);

  // Drain the sampler in batches of 3.
  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = drs.next(3)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  // Sorted, the output must be exactly 0..9 with no duplicates or gaps.
  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}
863 | |
// With multiple replicas, the union of all replicas' samples must cover the
// dataset: with allow_duplicates the load is padded so every replica gets
// ceil(n/replicas) samples (some indices twice); without it, each replica
// gets floor(n/replicas) samples and the remainder is dropped.
TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;

    // One sampler per replica rank.
    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    // Drain every replica, accumulating all produced indices.
    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      // Each replica contributes exactly local_sample_count indices.
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // allow_duplicates: 3 replicas * ceil(10/3)=4 -> 12 samples, two repeats.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    // no duplicates: 3 replicas * floor(10/3)=3 -> 9 samples, one dropped.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}
905 | |
906 | TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) { |
907 | { |
908 | samplers::DistributedRandomSampler a(10); |
909 | ASSERT_EQ(a.index(), 0); |
910 | std::stringstream stream; |
911 | torch::save(a, stream); |
912 | |
913 | samplers::DistributedRandomSampler b(10); |
914 | torch::load(b, stream); |
915 | ASSERT_EQ(b.index(), 0); |
916 | } |
917 | { |
918 | samplers::DistributedRandomSampler a(10); |
919 | a.next(3); |
920 | a.next(4); |
921 | ASSERT_EQ(a.index(), 7); |
922 | std::stringstream stream; |
923 | torch::save(a, stream); |
924 | |
925 | samplers::DistributedRandomSampler b(10); |
926 | torch::load(b, stream); |
927 | ASSERT_EQ(b.index(), 7); |
928 | } |
929 | { |
930 | samplers::DistributedRandomSampler a(10); |
931 | a.set_epoch(3); |
932 | std::stringstream stream; |
933 | torch::save(a, stream); |
934 | |
935 | samplers::DistributedRandomSampler b(10); |
936 | torch::load(b, stream); |
937 | ASSERT_EQ(b.epoch(), 3); |
938 | } |
939 | } |
940 | |
941 | TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) { |
942 | size_t sample_count = 10; |
943 | size_t batch_size = 3; |
944 | samplers::DistributedSequentialSampler dss(sample_count); |
945 | |
946 | std::vector<size_t> res; |
947 | torch::optional<std::vector<size_t>> idx; |
948 | while ((idx = dss.next(batch_size)).has_value()) { |
949 | res.insert(std::end(res), std::begin(*idx), std::end(*idx)); |
950 | } |
951 | |
952 | ASSERT_EQ(res.size(), sample_count); |
953 | |
954 | std::sort(res.begin(), res.end()); |
955 | for (const auto i : c10::irange(res.size())) { |
956 | ASSERT_EQ(res[i], i); |
957 | } |
958 | } |
959 | |
TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  // Creates one sampler per replica, drains each of them in `batch_size`
  // chunks into a single vector, and checks that the replicas together
  // cover exactly the indices listed in `output`.
  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      // Each replica must contribute exactly its local share of the samples.
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // With duplicates allowed each replica draws ceil(10/3) = 4 samples, so
    // two indices appear twice across the replicas.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    // Without duplicates each replica draws floor(10/3) = 3 samples, so one
    // index is dropped.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}
1002 | |
1003 | TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) { |
1004 | { |
1005 | samplers::DistributedSequentialSampler a(10); |
1006 | ASSERT_EQ(a.index(), 0); |
1007 | std::stringstream stream; |
1008 | torch::save(a, stream); |
1009 | |
1010 | samplers::DistributedSequentialSampler b(10); |
1011 | torch::load(b, stream); |
1012 | ASSERT_EQ(b.index(), 0); |
1013 | } |
1014 | { |
1015 | samplers::DistributedSequentialSampler a(10); |
1016 | a.next(3); |
1017 | a.next(4); |
1018 | ASSERT_EQ(a.index(), 7); |
1019 | std::stringstream stream; |
1020 | torch::save(a, stream); |
1021 | |
1022 | samplers::DistributedSequentialSampler b(10); |
1023 | torch::load(b, stream); |
1024 | ASSERT_EQ(b.index(), 7); |
1025 | } |
1026 | } |
1027 | |
1028 | TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) { |
1029 | DataLoaderOptions partial_options; |
1030 | FullDataLoaderOptions full_options(partial_options); |
1031 | ASSERT_EQ(full_options.batch_size, 1); |
1032 | ASSERT_FALSE(full_options.drop_last); |
1033 | ASSERT_EQ(full_options.workers, 0); |
1034 | ASSERT_EQ(full_options.max_jobs, 0); |
1035 | ASSERT_FALSE(full_options.timeout.has_value()); |
1036 | ASSERT_TRUE(full_options.enforce_ordering); |
1037 | } |
1038 | |
1039 | TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) { |
1040 | auto partial_options = DataLoaderOptions(32).workers(10); |
1041 | FullDataLoaderOptions full_options(partial_options); |
1042 | ASSERT_EQ(full_options.batch_size, 32); |
1043 | ASSERT_EQ(full_options.max_jobs, 2 * 10); |
1044 | } |
1045 | |
1046 | TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) { |
1047 | auto data_loader = torch::data::make_data_loader( |
1048 | DummyDataset().map(transforms::Lambda<int>([](int x) { return x + 1; }))); |
1049 | ASSERT_EQ(data_loader->options().batch_size, 1); |
1050 | } |
1051 | |
// A dataset that deliberately reports no size (nullopt); used below to
// verify that sampler construction rejects unsized datasets.
struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
  torch::data::Example<> get(size_t i) override {
    return {torch::ones(i), torch::ones(i)};
  }
  torch::optional<size_t> size() const noexcept override {
    return torch::nullopt;
  }
};
1060 | |
1061 | TEST( |
1062 | DataLoaderTest, |
1063 | MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) { |
1064 | ASSERT_THROWS_WITH( |
1065 | torch::data::make_data_loader(UnsizedDataset{}), |
1066 | "Expected the dataset to be sized in order to construct the Sampler" ); |
1067 | } |
1068 | |
1069 | TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) { |
1070 | auto data_loader = torch::data::make_data_loader(DummyDataset(), 32); |
1071 | auto begin = data_loader->begin(); |
1072 | ASSERT_EQ(begin, begin); |
1073 | auto end = data_loader->end(); |
1074 | ASSERT_EQ(end, end); |
1075 | } |
1076 | |
1077 | TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) { |
1078 | auto data_loader = torch::data::make_data_loader(DummyDataset(), 32); |
1079 | auto i = data_loader->begin(); |
1080 | auto j = data_loader->begin(); |
1081 | ASSERT_NE(i, j); |
1082 | ++j; |
1083 | ASSERT_NE(i, j); |
1084 | } |
1085 | |
1086 | TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) { |
1087 | auto data_loader = torch::data::make_data_loader(DummyDataset(), 32); |
1088 | auto i = data_loader->end(); |
1089 | auto j = data_loader->end(); |
1090 | ASSERT_EQ(i, j); |
1091 | } |
1092 | |
1093 | TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) { |
1094 | DummyDataset dataset; |
1095 | auto data_loader = |
1096 | torch::data::make_data_loader(dataset, dataset.size().value() / 4); |
1097 | auto i = data_loader->begin(); |
1098 | auto end = data_loader->end(); |
1099 | ASSERT_NE(i, end); |
1100 | ++i; |
1101 | ASSERT_NE(i, end); |
1102 | ++i; |
1103 | ASSERT_NE(i, end); |
1104 | ++i; |
1105 | ASSERT_NE(i, end); |
1106 | ++i; |
1107 | ASSERT_EQ(i, end); |
1108 | } |
1109 | |
TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  // `j` is a copy of `i`, but DataLoader iterators are single-pass and share
  // the underlying batch stream: advancing either copy consumes a batch for
  // both.
  auto i = data_loader->begin();
  auto j = i;
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  // Advancing `i` consumes the first of the two batches.
  ++i;
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  // Advancing `j` consumes the second batch; both copies are now exhausted.
  ++j;
  ASSERT_EQ(i, end);
  ASSERT_EQ(j, end);
}
1126 | |
1127 | TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) { |
1128 | DummyDataset dataset; |
1129 | auto data_loader = |
1130 | torch::data::make_data_loader<torch::data::samplers::SequentialSampler>( |
1131 | dataset, |
1132 | // NOLINTNEXTLINE(bugprone-argument-comment) |
1133 | /*batch_size=*/1); |
1134 | auto iterator = data_loader->begin(); |
1135 | std::vector<int> expected = {1}; |
1136 | ASSERT_EQ(*iterator, expected); |
1137 | ASSERT_EQ(*iterator, expected); |
1138 | ++iterator; |
1139 | expected[0] = 2; |
1140 | ASSERT_EQ(*iterator, expected); |
1141 | ASSERT_EQ(*iterator, expected); |
1142 | ++iterator; |
1143 | expected[0] = 3; |
1144 | ASSERT_EQ(*iterator, expected); |
1145 | ASSERT_EQ(*iterator, expected); |
1146 | } |
1147 | |
1148 | TEST(DataLoaderTest, CanUseIteratorAlgorithms) { |
1149 | struct D : datasets::BatchDataset<D, int> { |
1150 | int get_batch(torch::ArrayRef<size_t> indices) override { |
1151 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
1152 | return 1 + indices.front(); |
1153 | } |
1154 | torch::optional<size_t> size() const override { |
1155 | return 10; |
1156 | } |
1157 | }; |
1158 | |
1159 | D dataset; |
1160 | auto data_loader = |
1161 | torch::data::make_data_loader<torch::data::samplers::SequentialSampler>( |
1162 | dataset, 1); |
1163 | std::vector<int> values; |
1164 | std::copy( |
1165 | data_loader->begin(), data_loader->end(), std::back_inserter(values)); |
1166 | std::vector<int> expected(dataset.size().value()); |
1167 | std::iota(expected.begin(), expected.end(), size_t(1)); |
1168 | ASSERT_EQ(values, expected); |
1169 | } |
1170 | |
1171 | TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) { |
1172 | DummyDataset dataset; |
1173 | auto data_loader = |
1174 | torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2)); |
1175 | auto i = data_loader->begin(); |
1176 | ASSERT_THROWS_WITH( |
1177 | data_loader->begin(), |
1178 | "Attempted to get a new DataLoader iterator " |
1179 | "while another iterator is not yet exhausted" ); |
1180 | } |
1181 | |
1182 | TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) { |
1183 | DummyDataset dataset; |
1184 | auto data_loader = |
1185 | torch::data::make_data_loader(dataset, dataset.size().value()); |
1186 | auto i = data_loader->begin(); |
1187 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
1188 | ASSERT_NO_THROW(++i); |
1189 | ASSERT_THROWS_WITH(++i, "Attempted to increment iterator past the end" ); |
1190 | } |
1191 | |
1192 | TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) { |
1193 | DummyDataset dataset; |
1194 | auto data_loader = |
1195 | torch::data::make_data_loader(dataset, dataset.size().value()); |
1196 | auto i = data_loader->begin(); |
1197 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
1198 | ASSERT_NO_THROW(++i); |
1199 | ASSERT_THROWS_WITH( |
1200 | *i, "Attempted to dereference iterator that was past the end" ); |
1201 | } |
1202 | |
1203 | TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) { |
1204 | DummyDataset dataset; |
1205 | auto data_loader = |
1206 | torch::data::make_data_loader(dataset, dataset.size().value()); |
1207 | auto i = data_loader->end(); |
1208 | ASSERT_THROWS_WITH( |
1209 | ++i, |
1210 | "Incrementing the DataLoader's past-the-end iterator is not allowed" ); |
1211 | } |
1212 | |
1213 | TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) { |
1214 | DummyDataset dataset; |
1215 | auto data_loader = |
1216 | torch::data::make_data_loader(dataset, dataset.size().value()); |
1217 | auto i = data_loader->end(); |
1218 | ASSERT_THROWS_WITH( |
1219 | *i, |
1220 | "Dereferencing the DataLoader's past-the-end iterator is not allowed" ); |
1221 | } |
1222 | |
1223 | TEST(DataLoaderTest, YieldsCorrectBatchSize) { |
1224 | DummyDataset dataset; |
1225 | auto data_loader = torch::data::make_data_loader(dataset, 25); |
1226 | auto iterator = data_loader->begin(); |
1227 | ASSERT_EQ(iterator->size(), 25); |
1228 | ASSERT_EQ((++iterator)->size(), 25); |
1229 | ASSERT_EQ((++iterator)->size(), 25); |
1230 | ASSERT_EQ((++iterator)->size(), 25); |
1231 | ASSERT_EQ(++iterator, data_loader->end()); |
1232 | } |
1233 | |
1234 | TEST( |
1235 | DataLoaderTest, |
1236 | ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) { |
1237 | DummyDataset dataset; |
1238 | auto data_loader = torch::data::make_data_loader( |
1239 | dataset, DataLoaderOptions(33).drop_last(false)); |
1240 | auto iterator = data_loader->begin(); |
1241 | ASSERT_EQ(iterator->size(), 33); |
1242 | ASSERT_EQ((++iterator)->size(), 33); |
1243 | ASSERT_EQ((++iterator)->size(), 33); |
1244 | ASSERT_EQ((++iterator)->size(), 1); |
1245 | ASSERT_EQ(++iterator, data_loader->end()); |
1246 | } |
1247 | |
1248 | TEST( |
1249 | DataLoaderTest, |
1250 | DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) { |
1251 | DummyDataset dataset; |
1252 | auto data_loader = torch::data::make_data_loader( |
1253 | dataset, DataLoaderOptions(33).drop_last(true)); |
1254 | auto iterator = data_loader->begin(); |
1255 | ASSERT_EQ(iterator->size(), 33); |
1256 | ASSERT_EQ((++iterator)->size(), 33); |
1257 | ASSERT_EQ((++iterator)->size(), 33); |
1258 | ASSERT_EQ(++iterator, data_loader->end()); |
1259 | } |
1260 | |
1261 | TEST(DataLoaderTest, RespectsTimeout) { |
1262 | struct Baton { |
1263 | std::condition_variable cv; |
1264 | std::mutex mutex; |
1265 | }; |
1266 | |
1267 | struct D : datasets::Dataset<DummyDataset, int> { |
1268 | D(std::shared_ptr<Baton> b) : baton(std::move(b)) {} |
1269 | int get(size_t index) override { |
1270 | std::unique_lock<std::mutex> lock(baton->mutex); |
1271 | baton->cv.wait_for(lock, 1000 * kMillisecond); |
1272 | return 0; |
1273 | } |
1274 | torch::optional<size_t> size() const override { |
1275 | return 100; |
1276 | } |
1277 | std::shared_ptr<Baton> baton; |
1278 | }; |
1279 | |
1280 | auto baton = std::make_shared<Baton>(); |
1281 | |
1282 | auto data_loader = torch::data::make_data_loader( |
1283 | D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond)); |
1284 | |
1285 | auto start = std::chrono::system_clock::now(); |
1286 | |
1287 | ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout" ); |
1288 | baton->cv.notify_one(); |
1289 | |
1290 | auto end = std::chrono::system_clock::now(); |
1291 | auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start); |
1292 | ASSERT_LT(duration.count(), 1); |
1293 | } |
1294 | |
1295 | // stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11 |
// Blocks each caller of wait() until `target` threads have arrived; the last
// arrival releases all waiters.
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    --counter_;
    if (counter_ == 0) {
      // Last thread to arrive wakes everyone else up.
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return this->counter_ == 0; });
    }
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};
1311 | |
1312 | // On the OrderingTest: This test is intended to verify that the |
1313 | // `enforce_ordering` option of the dataloader works correctly. The reason this |
1314 | // flag exists is because when the dataloader has multiple workers (threads) |
1315 | // enabled and this flag is not set, the order in which worker threads finish |
1316 | // loading their respective batch and push it back to the dataloader's main |
1317 | // thread (for outside consumption) is not deterministic. Imagine the sampler is |
1318 | // a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each index |
1319 | // will be a single "job". Inside the dataloader, worker threads block until a |
// job is available. It is not deterministic which worker thread wakes up first
// to dequeue a particular batch. Further, some worker threads may take longer
1322 | // than others to read the data for their index. As such, it could be that |
1323 | // worker thread 2 finishes before all other threads and returns its batch to |
1324 | // the main thread. In that case, the dataloader iterator would return the datum |
1325 | // at index 2 batch, and afterwards the datum from whatever thread finishes |
1326 | // next. As such, the user may see data from indices 2, 0, 3, 1. On another run |
1327 | // of the same dataloader on the same data, threads may be scheduled differently |
1328 | // and return in order 0, 2, 3, 1. To force this ordering to deterministically |
1329 | // be 0, 1, 2, 3, the `enforce_ordering` flag can be set to true. In that case, |
1330 | // the dataloader will use a *sequencer* internally which keeps track of which |
1331 | // datum is expected next, and buffers any other results until that next |
1332 | // expected value arrives. For example, workers 1, 2, 3 may finish before worker |
1333 | // 0. If `enforce_ordering` is true, the sequencer will internally buffer the |
1334 | // results from 1, 2, 3 until worker 0 finishes. Only then does the dataloader |
1335 | // return the datum from worker 0 to the user (and then datum 1 the next time, |
1336 | // then 2 and so on). |
1337 | // |
1338 | // The way the test works is that we start |
1339 | // `kNumberOfWorkers` workers in the dataloader, which each get an index from a |
1340 | // `SequentialSampler` in the range `0...kNumberOfWorkers-1`. Each worker thread |
1341 | // has a copy of the dataset, and thus `get_batch()` is called on the |
1342 | // thread-local copy in each worker. We want to simulate out-of-order completion |
// of these threads. For this, we set a barrier in the `get_batch()`
1344 | // method to make sure every worker has some index to fetch assigned. Further, |
1345 | // each worker thread has a unique ID in `0...kNumberOfWorkers-1`. |
1346 | // There is a hard-coded ordering, `kOrderInWhichWorkersReturnTheirBatch`, in |
1347 | // which we want the worker threads to return. For this, an iterator into this |
// order is maintained. When the dereferenced iterator (the current order index)
1349 | // matches the thread ID of a worker, it knows it can now return its index as |
1350 | // well as progress the iterator. Inside the dataloader, the sequencer should |
1351 | // buffer these indices such that they are ultimately returned in order. |
1352 | |
1353 | namespace ordering_test { |
1354 | namespace { |
1355 | const size_t kNumberOfWorkers = 10; |
1356 | const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch = |
1357 | {3, 7, 0, 5, 4, 8, 2, 1, 9, 6}; |
1358 | } // namespace |
1359 | |
// Dataset that forces its worker-thread copies to finish in the fixed order
// kOrderInWhichWorkersReturnTheirBatch, so the test can observe whether the
// data loader's sequencer restores sampler order.
struct Dataset : datasets::BatchDataset<Dataset, size_t> {
  Dataset() = default;

  // This copy constructor will be called when we copy the dataset into a
  // particular thread. Each copy receives a distinct, monotonically
  // increasing ID from the shared atomic counter.
  Dataset(const Dataset& other) {
    static std::atomic<size_t> counter{0};
    thread_id_ = counter.fetch_add(1);
  }

  Dataset(Dataset&& other) noexcept = default;
  Dataset& operator=(const Dataset& other) = delete;
  Dataset& operator=(Dataset&& other) noexcept = delete;

  // Returns its index only once it is this worker's turn according to
  // kOrderInWhichWorkersReturnTheirBatch, producing a deterministic
  // out-of-order completion sequence across worker threads.
  size_t get_batch(torch::ArrayRef<size_t> indices) override {
    // Function-local statics are shared by all copies of the dataset, i.e.
    // by all worker threads.
    static Barrier barrier(kNumberOfWorkers);
    static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
    static std::condition_variable cv;
    static std::mutex mutex;

    // Wait for all threads to get an index batch and arrive here.
    barrier.wait();

    // Block until the hard-coded order says it is this thread's turn, then
    // advance the order and wake the remaining waiters.
    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
    ++order_iterator;
    lock.unlock();
    cv.notify_all();

    return indices.front();
  }

  torch::optional<size_t> size() const override {
    return kNumberOfWorkers;
  }

  // Unique per-copy worker ID, assigned in the copy constructor.
  size_t thread_id_ = 0;
};
1398 | |
1399 | } // namespace ordering_test |
1400 | |
1401 | TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) { |
1402 | auto data_loader = torch::data::make_data_loader( |
1403 | ordering_test::Dataset{}, |
1404 | torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers), |
1405 | DataLoaderOptions() |
1406 | .batch_size(1) |
1407 | .workers(ordering_test::kNumberOfWorkers) |
1408 | .enforce_ordering(true)); |
1409 | std::vector<size_t> output; |
1410 | for (size_t value : *data_loader) { |
1411 | output.push_back(value); |
1412 | } |
1413 | std::vector<size_t> expected(ordering_test::kNumberOfWorkers); |
1414 | std::iota(expected.begin(), expected.end(), size_t(0)); |
1415 | ASSERT_EQ(expected, output); |
1416 | } |
1417 | |
1418 | TEST(DataLoaderTest, Reset) { |
1419 | DummyDataset dataset; |
1420 | auto data_loader = |
1421 | torch::data::make_data_loader(dataset, dataset.size().value() / 2); |
1422 | auto end = data_loader->end(); |
1423 | |
1424 | auto iterator = data_loader->begin(); |
1425 | ASSERT_NE(iterator, end); |
1426 | ASSERT_NE(++iterator, end); |
1427 | ASSERT_EQ(++iterator, end); |
1428 | |
1429 | iterator = data_loader->begin(); |
1430 | ASSERT_NE(iterator, end); |
1431 | ASSERT_NE(++iterator, end); |
1432 | ASSERT_EQ(++iterator, end); |
1433 | |
1434 | iterator = data_loader->begin(); |
1435 | ASSERT_NE(iterator, end); |
1436 | ASSERT_NE(++iterator, end); |
1437 | ASSERT_EQ(++iterator, end); |
1438 | } |
1439 | |
1440 | TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) { |
1441 | struct D : datasets::Dataset<DummyDataset, int> { |
1442 | int get(size_t index) override { |
1443 | throw std::invalid_argument("badness" ); |
1444 | } |
1445 | torch::optional<size_t> size() const override { |
1446 | return 100; |
1447 | } |
1448 | }; |
1449 | |
1450 | auto data_loader = torch::data::make_data_loader( |
1451 | D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2)); |
1452 | auto iterator = data_loader->begin(); |
1453 | |
1454 | try { |
1455 | (void)*iterator; |
1456 | } catch (torch::data::WorkerException& e) { |
1457 | ASSERT_EQ( |
1458 | e.what(), |
1459 | std::string("Caught exception in DataLoader worker thread. " |
1460 | "Original message: badness" )); |
1461 | // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
1462 | ASSERT_THROW( |
1463 | std::rethrow_exception(e.original_exception), std::invalid_argument); |
1464 | } |
1465 | } |
1466 | |
1467 | TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) { |
1468 | const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10; |
1469 | |
1470 | struct D : datasets::StatefulDataset<D, int, size_t> { |
1471 | torch::optional<int> get_batch(size_t) override { |
1472 | if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) { |
1473 | return counter++; |
1474 | } |
1475 | return torch::nullopt; |
1476 | } |
1477 | torch::optional<size_t> size() const override { |
1478 | return 100; |
1479 | } |
1480 | void reset() override { |
1481 | counter = 0; |
1482 | } |
1483 | void save(torch::serialize::OutputArchive& archive) const override{}; |
1484 | void load(torch::serialize::InputArchive& archive) override {} |
1485 | int counter = 0; |
1486 | }; |
1487 | |
1488 | auto data_loader = torch::data::make_data_loader(D{}); |
1489 | |
1490 | for (const auto i : c10::irange(10)) { |
1491 | const auto number_of_iterations = |
1492 | std::distance(data_loader->begin(), data_loader->end()); |
1493 | ASSERT_EQ( |
1494 | number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts) |
1495 | << "epoch " << i; |
1496 | } |
1497 | |
1498 | for (const int i : *data_loader) { |
1499 | ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts); |
1500 | } |
1501 | } |
1502 | |
TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  // Stateful dataset shared among worker threads — hence the mutex guarding
  // the example counter in get_batch().
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> lock(mutex);
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      // nullopt signals exhaustion of the current epoch.
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
    std::mutex mutex;
  };

  // Stateful datasets are wrapped via make_shared_dataset so that all
  // workers operate on the same instance (and the same counter).
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  // Each of ten epochs must yield exactly the exhaustion count of examples.
  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1543 | |
TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Counting stateful dataset: hands out increasing integers, then signals
  // exhaustion (nullopt) after a fixed number of examples.
  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  // Chain two batch transforms onto the stateful dataset (int -> string ->
  // tensor) to verify that map() composes with stateful datasets.
  auto data_loader = torch::data::make_data_loader(
      D().map(transforms::BatchLambda<int, std::string>(
          [](int x) { return std::to_string(x); }))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              [](const std::string& x) {
                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
              })),
      DataLoaderOptions{});

  // Each of ten epochs must yield exactly the exhaustion count of examples.
  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const torch::Tensor& t : *data_loader) {
    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}
1586 | |
TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  // Stateful dataset returning `batch_size` identical examples per call;
  // data tensors have batch_size + 1 elements and targets batch_size - 1,
  // so the stacked shapes asserted below are distinguishable.
  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        counter += batch_size;
        std::vector<Example<>> batch(
            /*count=*/batch_size,
            Example<>{
                torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
        return batch;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override{};
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto d = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // Notice that the `get_batch()` of the dataset returns a vector<Example>, but
  // the `Stack` collation stacks the tensors into one.
  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
  ASSERT_TRUE(batch.has_value());
  ASSERT_EQ(batch->data.size(0), kBatchSize);
  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  ASSERT_EQ(batch->target.size(0), kBatchSize);
  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);

  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
}
1630 | |
1631 | // This test tests the core function for iterate through a chunk dataset. It |
1632 | // contains test cases with different parameter combination. (For example, |
1633 | // different prefetch count, batch size and data loader worker count). It |
1634 | // verifies the return batches size and content when the order is deterministic. |
TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  // different prefetch count for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t prefetch_counts[] = {1, 2, 3, 4};

  // different batch size for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {5, 7};

  // test with/without worker threads
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t dataloader_worker_counts[] = {0, 2};

  const size_t total_example_count = 35;
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // test functionality across epoch boundary
  const int epoch_count = 2;

  // Sweep every combination of prefetch count, batch size, and worker count.
  for (auto prefetch_count : prefetch_counts) {
    for (auto batch_size : batch_sizes) {
      for (auto dataloader_worker_count : dataloader_worker_counts) {
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));

        auto data_loader = torch::data::make_data_loader(
            dataset,
            DataLoaderOptions(batch_size).workers(dataloader_worker_count));

        for (const auto epoch_index : c10::irange(epoch_count)) {
          (void)epoch_index; // Suppress unused variable warning
          // Tracks which of the 35 examples have been seen this epoch.
          std::vector<bool> result(total_example_count, false);
          int iteration_count = 0;
          for (auto iterator = data_loader->begin();
               iterator != data_loader->end();
               ++iterator, ++iteration_count) {
            DummyChunkDataReader::BatchType& batch = *iterator;
            ASSERT_EQ(batch.size(), batch_size);

            // When prefetch_count is equal to 1 and no worker thread, the batch
            // order is deterministic. So we can verify elements in each batch.
            if (prefetch_count == 1 && dataloader_worker_count == 0) {
              for (const auto j : c10::irange(batch_size)) {
                ASSERT_EQ(batch[j], iteration_count * batch_size + j);
              }
            }
            for (const auto j : c10::irange(batch_size)) {
              result[batch[j]] = true;
            }
          }

          // Regardless of ordering, every example must appear in the epoch.
          for (auto data : result) {
            ASSERT_EQ(data, true);
          }
        }
      }
    }
  }
}
1705 | |
1706 | TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) { |
1707 | const size_t prefetch_count = 1; |
1708 | const size_t batch_size = 5; |
1709 | const size_t requested_batch_size = 6; |
1710 | |
1711 | DummyChunkDataReader data_reader; |
1712 | samplers::SequentialSampler sampler(0); |
1713 | |
1714 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
1715 | DummyChunkDataReader, |
1716 | samplers::SequentialSampler, |
1717 | samplers::SequentialSampler>> |
1718 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
1719 | DummyChunkDataReader, |
1720 | samplers::SequentialSampler, |
1721 | samplers::SequentialSampler>>( |
1722 | data_reader, |
1723 | sampler, |
1724 | sampler, |
1725 | datasets::ChunkDatasetOptions(prefetch_count, batch_size)); |
1726 | |
1727 | auto data_loader = torch::data::make_data_loader( |
1728 | dataset, DataLoaderOptions(requested_batch_size).workers(0)); |
1729 | |
1730 | std::string exception_msg = |
1731 | "The requested batch size does not match with the initialized batch " |
1732 | "size.\n The requested batch size is 6, while the dataset is created" |
1733 | " with batch size equal to 5" ; |
1734 | |
1735 | ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg); |
1736 | } |
1737 | |
1738 | TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) { |
1739 | struct DummyEmptyChunkDataReader : datasets::ChunkDataReader<int> { |
1740 | public: |
1741 | using BatchType = datasets::ChunkDataReader<int>::ChunkType; |
1742 | |
1743 | BatchType read_chunk(size_t chunk_index) override { |
1744 | return {}; |
1745 | } |
1746 | |
1747 | size_t chunk_count() override { |
1748 | return 1; |
1749 | }; |
1750 | |
1751 | void reset() override{}; |
1752 | }; |
1753 | |
1754 | const size_t prefetch_count = 1; |
1755 | const size_t batch_size = 5; |
1756 | DummyEmptyChunkDataReader data_reader; |
1757 | samplers::SequentialSampler sampler(0); |
1758 | |
1759 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
1760 | DummyEmptyChunkDataReader, |
1761 | samplers::SequentialSampler, |
1762 | samplers::SequentialSampler>> |
1763 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
1764 | DummyEmptyChunkDataReader, |
1765 | samplers::SequentialSampler, |
1766 | samplers::SequentialSampler>>( |
1767 | data_reader, |
1768 | sampler, |
1769 | sampler, |
1770 | datasets::ChunkDatasetOptions(prefetch_count, batch_size)); |
1771 | |
1772 | auto data_loader = torch::data::make_data_loader( |
1773 | dataset, DataLoaderOptions(batch_size).workers(0)); |
1774 | |
1775 | for (auto iterator = data_loader->begin(); iterator != data_loader->end(); |
1776 | ++iterator) { |
1777 | ASSERT_EQ(iterator->size(), 0); |
1778 | } |
1779 | } |
1780 | |
1781 | TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) { |
1782 | struct D : public datasets::ChunkDataReader<int> { |
1783 | public: |
1784 | using BatchType = datasets::ChunkDataReader<int>::ChunkType; |
1785 | |
1786 | BatchType read_chunk(size_t chunk_index) override { |
1787 | BatchType batch_data(10, 0); |
1788 | return batch_data; |
1789 | } |
1790 | |
1791 | size_t chunk_count() override { |
1792 | return 2; |
1793 | }; |
1794 | |
1795 | void reset() override{}; |
1796 | }; |
1797 | |
1798 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
1799 | const size_t batch_sizes[] = {17, 30}; |
1800 | D data_reader; |
1801 | samplers::SequentialSampler sampler(0); |
1802 | |
1803 | for (auto batch_size : batch_sizes) { |
1804 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
1805 | D, |
1806 | samplers::SequentialSampler, |
1807 | samplers::SequentialSampler>> |
1808 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
1809 | D, |
1810 | samplers::SequentialSampler, |
1811 | samplers::SequentialSampler>>( |
1812 | data_reader, |
1813 | sampler, |
1814 | sampler, |
1815 | datasets::ChunkDatasetOptions(1, batch_size)); |
1816 | |
1817 | auto data_loader = torch::data::make_data_loader( |
1818 | dataset, DataLoaderOptions(batch_size).workers(0)); |
1819 | |
1820 | for (auto iterator = data_loader->begin(); iterator != data_loader->end(); |
1821 | ++iterator) { |
1822 | DummyChunkDataReader::BatchType batch = *iterator; |
1823 | auto batch_size = batch.size(); |
1824 | if (batch_size == 17) { |
1825 | ASSERT_TRUE(batch.size() == 17 || batch.size() == 3); |
1826 | } |
1827 | if (batch_size == 30) { |
1828 | ASSERT_TRUE(batch.size() == 20); |
1829 | } |
1830 | } |
1831 | } |
1832 | } |
1833 | |
1834 | TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) { |
1835 | const size_t prefetch_count = 2; |
1836 | const size_t batch_size = 5; |
1837 | |
1838 | DummyChunkDataReader data_reader; |
1839 | samplers::SequentialSampler sampler(0); |
1840 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
1841 | DummyChunkDataReader, |
1842 | samplers::SequentialSampler, |
1843 | samplers::SequentialSampler>> |
1844 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
1845 | DummyChunkDataReader, |
1846 | samplers::SequentialSampler, |
1847 | samplers::SequentialSampler>>( |
1848 | data_reader, |
1849 | sampler, |
1850 | sampler, |
1851 | datasets::ChunkDatasetOptions(prefetch_count, batch_size)); |
1852 | |
1853 | samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler(); |
1854 | |
1855 | auto data_loader = torch::data::make_data_loader( |
1856 | dataset.map(transforms::BatchLambda< |
1857 | DummyChunkDataReader::BatchType, |
1858 | DummyChunkDataReader::DataType>( |
1859 | [](DummyChunkDataReader::BatchType batch) { |
1860 | return std::accumulate(batch.begin(), batch.end(), 0); |
1861 | })), |
1862 | DataLoaderOptions(batch_size).workers(0)); |
1863 | |
1864 | // before we start, the index should be 0. |
1865 | ASSERT_EQ(chunk_sampler.index(), 0); |
1866 | |
1867 | size_t sum = 0; |
1868 | for (auto iterator = data_loader->begin(); iterator != data_loader->end(); |
1869 | ++iterator) { |
1870 | sum += *iterator; |
1871 | } |
1872 | ASSERT_EQ(sum, 595); // sum([0, 35)) |
1873 | // 3 chunks, and when exhausted the value is already incremented. |
1874 | ASSERT_EQ(chunk_sampler.index(), 3); |
1875 | } |
1876 | |
1877 | TEST(DataLoaderTest, ChunkDatasetDoesNotHang) { |
1878 | const size_t prefetch_count = 2; |
1879 | const size_t batch_size = 5; |
1880 | // this will make the preloaders to wait till the `get_batch()` calls. |
1881 | const size_t cache_size = 10; |
1882 | |
1883 | DummyChunkDataReader data_reader; |
1884 | samplers::SequentialSampler sampler(0); |
1885 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
1886 | DummyChunkDataReader, |
1887 | samplers::SequentialSampler, |
1888 | samplers::SequentialSampler>> |
1889 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
1890 | DummyChunkDataReader, |
1891 | samplers::SequentialSampler, |
1892 | samplers::SequentialSampler>>( |
1893 | data_reader, |
1894 | sampler, |
1895 | sampler, |
1896 | datasets::ChunkDatasetOptions( |
1897 | prefetch_count, batch_size, cache_size)); |
1898 | |
1899 | auto data_loader = torch::data::make_data_loader( |
1900 | dataset.map(transforms::BatchLambda< |
1901 | DummyChunkDataReader::BatchType, |
1902 | DummyChunkDataReader::DataType>( |
1903 | [](DummyChunkDataReader::BatchType batch) { |
1904 | return std::accumulate(batch.begin(), batch.end(), 0); |
1905 | })), |
1906 | DataLoaderOptions(batch_size).workers(0)); |
1907 | // simply creates the iterator but no iteration. chunk preloaders are waiting |
1908 | // to fill the batch buffer but it is not draining. Still we need to exit |
1909 | // cleanly. |
1910 | auto iterator = data_loader->begin(); |
1911 | } |
1912 | |
1913 | // Test ChunkDataset save function. |
1914 | // Note [save/load ChunkDataset as ChunkSampler]: |
1915 | // The chunk sampler inside ChunkDataset is used in a separate thread pool other |
1916 | // than the main thread. Thus it is very hard to accurately estimate its status |
1917 | // when ChunkDataset::save/ChunkDataset::load is called. For the pure purpose of |
1918 | // testing, we utilize the implementation fact that the file format for sampler |
1919 | // serialization is the same as ChunkDataset serialization, and manually control |
1920 | // the chunk sampler by calling the sampler's save/load method for value |
1921 | // validation. This is only for testing the specific save/load functionality. In |
// a real use case, the user should still use the matching ChunkDataset::save
1923 | // ChunkDataset::load method. |
TEST(DataLoaderTest, ChunkDatasetSave) {
  const size_t chunk_count_ = 6;
  const size_t chunk_size = 10;

  // Reader with 6 identical chunks of 10 zero-valued examples each.
  struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return batch_data_;
    }

    size_t chunk_count() override {
      return chunk_count_;
    };

    void reset() override{};
    BatchType batch_data_ = BatchType(chunk_size, 0);
  };

  const size_t prefetch_count = 1;
  // Batch size equals chunk size (and the cache size below), so chunks are
  // written into the cache strictly one at a time.
  const size_t batch_size = chunk_size;
  const size_t dataloader_worker_count = 0;
  samplers::SequentialSampler sampler(0);
  const int epoch_count = 2;

  DummyTestChunkDataReader data_reader;

  // tested save_intervals
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t save_intervals[] = {1, 2};

  using datasets::ChunkDatasetOptions;

  for (auto save_interval : save_intervals) {
    auto tempfile = c10::make_tempfile();

    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyTestChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyTestChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            ChunkDatasetOptions(
                prefetch_count, batch_size, chunk_size /*cache size*/));

    auto data_loader = torch::data::make_data_loader(
        dataset,
        DataLoaderOptions(batch_size).workers(dataloader_worker_count));

    for (const auto epoch_index : c10::irange(epoch_count)) {
      (void)epoch_index; // Suppress unused variable warning
      unsigned iteration_count = 0;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator, ++iteration_count) {
        if ((iteration_count + 1) % save_interval == 0) {
          torch::save(*dataset, tempfile.name);

          samplers::SequentialSampler new_sampler(0);

          // See Note [save/load ChunkDataset as ChunkSampler]
          torch::load(new_sampler, tempfile.name);

          // Verify save logic. For ChunkDataset, the chunk data is stored in a
          // cache inside the dataset. One pool of threads is constantly
          // writing to the cache, and a different pool of threads is
          // constantly reading from the cache. Due to this asynchrony, at
          // the time of get_batch(), which chunk has been written to the cache
          // is not fully deterministic.
          // But we can still calculate a restricted window on the expected
          // output, hence verify the logic. In this test, the cache size is
          // configured to be the same as chunk size and batch size. So the
          // chunk data is written to the cache one chunk at a time: only once
          // the current batch is retrieved is the next chunk written. Now in
          // iteration 0, after the first batch is retrieved, when we save the
          // dataset status, there are three possible scenarios for the writer
          // thread:
          // 1. it hasn't started loading the next chunk data yet, so the
          // sequential sampler index is still 0;
          // 2. it started to load the second chunk, so the sequential sampler
          // index is at 1;
          // 3. it finished loading the second chunk, and started to load the
          // third chunk; because the cache is still fully occupied by the data
          // from the second chunk, it is waiting to write to the cache. At
          // this point, the sampler index is at 2.
          // So now we have a window of [0, 2], which is what we expect the
          // sampler to save the index from. Note that the sequential sampler
          // advances to the next index automatically inside next(). So when
          // the index is saved, it is the next index that gets saved instead
          // of the current one. In other words, after getting the first index
          // from the sequential sampler, it has already moved to the second
          // index, so that is what we save. As a result, we need to advance
          // the window by one, giving the expected window of [1, 3].
          // This analysis applies to every iteration, so in the general case
          // the saved index should fall into the range
          // [iteration_count + 1, iteration_count + 3], which is the
          // validation below.
          ASSERT_TRUE(
              new_sampler.index() >= iteration_count + 1 &&
              new_sampler.index() <= iteration_count + 3);
        }
      }
    }
  }
}
2034 | |
2035 | // Test ChunkDataset load function. |
TEST(DataLoaderTest, ChunkDatasetLoad) {
  auto tempfile = c10::make_tempfile();

  const size_t prefetch_count = 1;
  const size_t batch_size = 10;
  const size_t dataloader_worker_count = 0;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  const size_t skipped_chunk = 2;

  // Configure sampler to skip 2 chunks, then persist it as a checkpoint.
  {
    sampler.reset(data_reader.chunk_count());
    sampler.next(skipped_chunk);

    // See Note [save/load ChunkDataset as ChunkSampler]
    torch::save(sampler, tempfile.name);
  }

  // test functionality across epoch boundary. The first epoch should be
  // affected by the checkpoint, but the second should start normally.
  const int epoch_count = 2;

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, 20 /*cache size*/));

  // Restore the dataset's chunk sampler state from the checkpoint above.
  torch::load(*dataset, tempfile.name);

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));

  for (const auto epoch_index : c10::irange(epoch_count)) {
    int iteration_count = 0;

    // For the first epoch, the returned batch should be returned from the
    // third chunk, because the check point skipped the first two chunks. But
    // for the next epoch, it should start from the first batch.
    // DummyChunkDataReader's chunks hold 10, 5, and 20 examples respectively,
    // so skipping the first two chunks means the first value seen is 15.
    int initial_value = epoch_index == 0 ? 15 : 0;

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator, ++iteration_count) {
      DummyChunkDataReader::BatchType batch = *iterator;

      std::vector<int> expected_result;
      // Epoch 0 yields the remaining 20 examples (two full batches of 10);
      // epoch 1 yields all 35, so its final (4th) batch holds only 5.
      size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
      expected_result.resize(expected_size);
      std::iota(expected_result.begin(), expected_result.end(), initial_value);

      ASSERT_EQ(batch.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(batch.begin(), batch.end(), expected_result.begin()));

      initial_value += batch_size;
    }
  }

  samplers::SequentialSampler new_sampler(0);

  // See Note [save/load ChunkDataset as ChunkSampler]
  torch::load(new_sampler, tempfile.name);

  // The tempfile still holds the checkpoint written at the top of the test;
  // round-tripping it must reproduce the skipped-chunk index.
  ASSERT_EQ(new_sampler.index(), skipped_chunk);
}
2112 | |
TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
  const size_t chunk_size = 5;
  const size_t batch_size = 5;

  // Deterministic stand-in for a shuffling example sampler: reset() builds a
  // fixed interleaving of the indices (stride batch_size) so the test output
  // is fully predictable.
  class S : public samplers::Sampler<> {
   public:
    explicit S(size_t size) : size_(size), index_(0){};

    void reset(torch::optional<size_t> new_size = torch::nullopt) override {
      if (new_size.has_value()) {
        size_ = *new_size;
      }
      indices_.resize(size_);
      size_t index = 0;

      // Repeatedly sample every 5 indices.
      for (const auto i : c10::irange(batch_size)) {
        for (size_t j = 0; j < size_ / batch_size; ++j) {
          indices_[index++] = i + batch_size * j;
        }
      }
      index_ = 0;
    }

    // Returns the next batch of indices.
    torch::optional<std::vector<size_t>> next(size_t batch_size) override {
      const auto remaining_indices = size_ - index_;
      if (remaining_indices == 0) {
        return torch::nullopt;
      }
      auto return_size = std::min(batch_size, remaining_indices);
      std::vector<size_t> index_batch(
          indices_.begin() + index_, indices_.begin() + index_ + return_size);
      index_ += return_size;

      return index_batch;
    }

    // Serialization is a deliberate no-op; this sampler is not checkpointed.
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}

   private:
    size_t size_;
    std::vector<size_t> indices_;
    size_t index_{0};
  };

  // Reader whose chunk data is the chunk index repeated chunk_size times,
  // making it easy to tell which chunk each example came from.
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size, chunk_index);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    };

    void reset() override{};
    size_t chunk_count_;
  };

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {2, 3};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4, 5};

  samplers::SequentialSampler chunk_sampler(0);
  S example_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<
          datasets::ChunkDataset<D, samplers::SequentialSampler, S>>
          dataset = datasets::make_shared_dataset<
              datasets::ChunkDataset<D, samplers::SequentialSampler, S>>(
              data_reader,
              chunk_sampler,
              example_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count));

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      // Flatten every produced batch into one sequence for comparison.
      std::vector<int> result;
      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        std::copy(
            batch_result.begin(),
            batch_result.end(),
            std::back_inserter(result));
      }

      std::vector<int> expected_result;
      {
        // construct expected result: chunks are consumed in groups of
        // cross_chunk_shuffle_count, and within each group the stride-based
        // sampler S interleaves the group's chunk indices example by example.
        // The outer bound is ceil(chunk_count / cross_chunk_shuffle_count).
        for (const auto i : c10::irange(
                 (chunk_count + cross_chunk_shuffle_count - 1) /
                 cross_chunk_shuffle_count)) {
          for (const auto j : c10::irange(chunk_size)) {
            (void)j; // Suppress unused variable warning
            for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
              // Guard against a partial final group when chunk_count is not a
              // multiple of cross_chunk_shuffle_count.
              if (i * cross_chunk_shuffle_count + k < chunk_count) {
                expected_result.push_back(i * cross_chunk_shuffle_count + k);
              }
            }
          }
        }
      }

      ASSERT_EQ(result.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(result.begin(), result.end(), expected_result.begin()));
    }
  }
}
2241 | |
2242 | TEST(DataLoaderTest, CustomPreprocessPolicy) { |
2243 | const size_t chunk_size = 5; |
2244 | const size_t batch_size = 10; |
2245 | |
2246 | struct D : public datasets::ChunkDataReader<int> { |
2247 | public: |
2248 | using BatchType = datasets::ChunkDataReader<int>::ChunkType; |
2249 | D(size_t chunk_count) : chunk_count_(chunk_count) {} |
2250 | |
2251 | BatchType read_chunk(size_t chunk_index) override { |
2252 | BatchType batch_data(chunk_size); |
2253 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) |
2254 | auto rand_gen = []() { return std::rand() % 100; }; |
2255 | std::generate(batch_data.begin(), batch_data.end(), rand_gen); |
2256 | return batch_data; |
2257 | } |
2258 | |
2259 | size_t chunk_count() override { |
2260 | return chunk_count_; |
2261 | }; |
2262 | |
2263 | void reset() override{}; |
2264 | size_t chunk_count_; |
2265 | }; |
2266 | |
2267 | // custom preprocessing policy - sort the data ascendingly |
2268 | auto sorting_policy = [](std::vector<int>& raw_batch_data) { |
2269 | std::sort(raw_batch_data.begin(), raw_batch_data.end()); |
2270 | }; |
2271 | std::function<void(std::vector<int>&)> policy_function = sorting_policy; |
2272 | |
2273 | const size_t prefetch_count = 1; |
2274 | const size_t cache_size = 10; |
2275 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
2276 | const size_t cross_chunk_shuffle_counts[] = {1, 2}; |
2277 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
2278 | const size_t chunk_counts[] = {3, 4}; |
2279 | |
2280 | samplers::SequentialSampler chunk_sampler(0); |
2281 | |
2282 | for (auto chunk_count : chunk_counts) { |
2283 | for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) { |
2284 | D data_reader(chunk_count); |
2285 | |
2286 | datasets::SharedBatchDataset<datasets::ChunkDataset< |
2287 | D, |
2288 | samplers::SequentialSampler, |
2289 | samplers::SequentialSampler>> |
2290 | dataset = datasets::make_shared_dataset<datasets::ChunkDataset< |
2291 | D, |
2292 | samplers::SequentialSampler, |
2293 | samplers::SequentialSampler>>( |
2294 | data_reader, |
2295 | chunk_sampler, |
2296 | chunk_sampler, |
2297 | datasets::ChunkDatasetOptions( |
2298 | prefetch_count, |
2299 | batch_size, |
2300 | cache_size, |
2301 | cross_chunk_shuffle_count), |
2302 | policy_function); |
2303 | |
2304 | auto data_loader = torch::data::make_data_loader( |
2305 | dataset, DataLoaderOptions(batch_size).workers(0)); |
2306 | |
2307 | std::vector<int> result; |
2308 | for (auto iterator = data_loader->begin(); iterator != data_loader->end(); |
2309 | ++iterator) { |
2310 | auto batch_result = *iterator; |
2311 | if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) { |
2312 | for (unsigned i = 0; i < batch_result.size(); i += chunk_size) { |
2313 | ASSERT_TRUE(std::is_sorted( |
2314 | batch_result.begin() + i, |
2315 | batch_result.begin() + i + chunk_size)); |
2316 | } |
2317 | } else { |
2318 | ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end())); |
2319 | } |
2320 | } |
2321 | } |
2322 | } |
2323 | } |
2324 | |