#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <future>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

using namespace torch::data; // NOLINT

const std::chrono::milliseconds kMillisecond(1);

struct DummyDataset : datasets::Dataset<DummyDataset, int> {
  explicit DummyDataset(size_t size = 100) : size_(size) {}

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return size_;
  }

  size_t size_;
};

TEST(DataTest, DatasetCallsGetCorrectly) {
  DummyDataset d;
  std::vector<int> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<int> expected = {1, 2, 3, 4, 5};
  ASSERT_EQ(batch, expected);
}

TEST(DataTest, TransformCallsGetApplyCorrectly) {
  struct T : transforms::Transform<int, std::string> {
    std::string apply(int input) override {
      return std::to_string(input);
    }
  };

  auto d = DummyDataset{}.map(T{});
  std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(batch, expected);
}

// A dummy chunk data reader with 3 chunks and 35 examples in total. The
// chunks contain 10, 5, and 20 examples respectively.

struct DummyChunkDataReader : public datasets::ChunkDataReader<int> {
 public:
  using BatchType = datasets::ChunkDataReader<int>::ChunkType;
  using DataType = datasets::ChunkDataReader<int>::ExampleType;

  /// Read an entire chunk.
  BatchType read_chunk(size_t chunk_index) override {
    BatchType batch_data;
    int start_index = chunk_index == 0
        ? 0
        // NOLINTNEXTLINE(bugprone-fold-init-type)
        : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);

    batch_data.resize(chunk_sizes[chunk_index]);

    std::iota(batch_data.begin(), batch_data.end(), start_index);

    return batch_data;
  }

  size_t chunk_count() override {
    return chunk_count_;
  }

  void reset() override {}

  const static size_t chunk_count_ = 3;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
  size_t chunk_sizes[chunk_count_] = {10, 5, 20};
};
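
// A quick sanity check of the layout described above: start indices are the
// prefix sums of `chunk_sizes`, so chunk 0 holds {0..9}, chunk 1 holds
// {10..14}, and chunk 2 holds {15..34}. The chunk dataset tests below rely on
// this layout implicitly.
TEST(DataTest, DummyChunkDataReaderProducesExpectedLayout) {
  DummyChunkDataReader reader;
  DummyChunkDataReader::BatchType chunk = reader.read_chunk(0);
  ASSERT_EQ(chunk.front(), 0);
  ASSERT_EQ(chunk.back(), 9);
  chunk = reader.read_chunk(1);
  ASSERT_EQ(chunk.front(), 10);
  ASSERT_EQ(chunk.back(), 14);
  chunk = reader.read_chunk(2);
  ASSERT_EQ(chunk.front(), 15);
  ASSERT_EQ(chunk.back(), 34);
}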

TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  auto initialization_function = [&](size_t preloader_count,
                                     size_t batch_size,
                                     size_t cache_size,
                                     size_t cross_chunk_shuffle_count = 1) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(
                preloader_count,
                batch_size,
                cache_size,
                cross_chunk_shuffle_count));
  };

  ASSERT_THROWS_WITH(
      initialization_function(0, 1, 1),
      "Preloader count is 0. At least one preloader needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 0, 1),
      "Batch size is 0. A positive batch size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 1, 0),
      "Cache size is 0. A positive cache size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 5),
      "Cache size is less than batch size. Cache needs to be large enough to "
      "hold at least one batch.");
  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 20, 0),
      "cross_chunk_shuffle_count needs to be greater than 0.");
}

struct InfiniteStreamDataset
    : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> {
  std::vector<int> get_batch(size_t batch_size) override {
    std::vector<int> batch(batch_size);
    for (auto& i : batch) {
      i = counter++;
    }
    return batch;
  }

  torch::optional<size_t> size() const override {
    return torch::nullopt;
  }

  size_t counter = 0;
};

TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;

  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/39),
      kBatchSize);

  size_t batch_index = 0;
  for (auto& batch : *data_loader) {
    ASSERT_LT(batch_index, 3);
    ASSERT_EQ(batch.size(), kBatchSize);
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
    }
    batch_index += 1;
  }
  ASSERT_EQ(batch_index, 3);
}

TEST(DataTest, NoSequencerIsIdentity) {
  using namespace torch::data::detail::sequencers; // NOLINT
  NoSequencer<int> no_sequencer;
  const auto value = no_sequencer.next([] { return 5; }).value();
  ASSERT_EQ(value, 5);
}

TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}

TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);

  std::vector<size_t> v = {0, 2, 4, 3, 1};
  size_t index = 0;
  auto getter = [&v, &index]() { return S{v.at(index++)}; };

  // If the sequence number matches for the first batch, it should be
  // returned immediately.
  const auto batch = sequencer.next(getter);
  ASSERT_EQ(batch.value().sequence_number, 0);
  ASSERT_EQ(index, 1);

  // Now it should call the getter until it gets the next value.
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(index, 5);

  // The next three should come in order.
  for (size_t i = 2; i <= 4; ++i) {
    // New value doesn't matter. In fact, it shouldn't be accessed.
    ASSERT_EQ(i, sequencer.next(getter).value().sequence_number);
    // The index doesn't change.
    ASSERT_EQ(index, 5);
  }
}

TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
  using InputBatch = std::vector<int>;
  using OutputBatch = std::string;
  DummyDataset d;
  auto e = d.map(transforms::BatchLambda<InputBatch, OutputBatch>(
      [](std::vector<int> input) {
        return std::to_string(std::accumulate(input.begin(), input.end(), 0));
      }));
  ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string("20"));
}

TEST(DataTest, LambdaAppliesFunctionToExample) {
  auto d = DummyDataset().map(transforms::Lambda<int, std::string>(
      static_cast<std::string (*)(int)>(std::to_string)));
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
}

TEST(DataTest, CollateReducesBatch) {
  auto d =
      DummyDataset().map(transforms::Collate<int>([](std::vector<int> input) {
        return std::accumulate(input.begin(), input.end(), 0);
      }));
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, CollationReducesBatch) {
  struct Summer : transforms::Collation<int> {
    int apply_batch(std::vector<int> input) override {
      return std::accumulate(input.begin(), input.end(), 0);
    }
  };
  auto d = DummyDataset().map(Summer{});
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  samplers::SequentialSampler sampler(10);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), std::vector<size_t>({8, 9}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(100).value(), std::vector<size_t>({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(
      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  {
    samplers::SequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::SequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
  samplers::RandomSampler sampler(10);

  std::vector<size_t> indices = sampler.next(3).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(5).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(2).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  ASSERT_FALSE(sampler.next(10).has_value());
}

TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_EQ(sampler.next(100).value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  {
    samplers::RandomSampler a(10);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);

    ASSERT_EQ(a.next(10).value(), b.next(10).value());
  }
  {
    samplers::RandomSampler a(10);
    a.next(3);
    ASSERT_EQ(a.index(), 3);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 3);

    auto b_sequence = b.next(10).value();
    ASSERT_EQ(b_sequence.size(), 7);
    ASSERT_EQ(a.next(10).value(), b_sequence);
  }
}

TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  ASSERT_EQ(sampler.next(10).value(), 10);
  ASSERT_EQ(sampler.next(2).value(), 2);
  ASSERT_EQ(sampler.next(85).value(), 85);
  ASSERT_EQ(sampler.next(123).value(), 3);
  ASSERT_FALSE(sampler.next(1).has_value());
}

TEST(DataTest, StreamSamplerResetsWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
  datasets::TensorDataset dataset(torch::eye(5));
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
  std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
  datasets::TensorDataset dataset(vector);
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, StackTransformWorksForExample) {
  struct D : public datasets::Dataset<D> {
    Example<> get(size_t index) override {
      return {tensor[index], 1 + tensor[index]};
    }

    torch::optional<size_t> size() const override {
      return tensor.size(0);
    }

    torch::Tensor tensor{torch::eye(4)};
  };

  auto d = D().map(transforms::Stack<Example<>>());

  Example<> batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
  ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));

  Example<> second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
  ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

TEST(DataTest, StackTransformWorksForTensorExample) {
  auto d = datasets::TensorDataset(torch::eye(4))
               .map(transforms::Stack<TensorExample>());

  TensorExample batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));

  TensorExample second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

// Template classes cannot be nested in functions.
template <typename Target>
struct T : transforms::TensorTransform<Target> {
  torch::Tensor operator()(torch::Tensor input) override {
    return input * 2;
  }
};

struct TensorStringDataset
    : datasets::
          Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
  Example<torch::Tensor, std::string> get(size_t index) override {
    return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, TensorTransformWorksForAnyTargetType) {
  auto d = TensorStringDataset().map(T<std::string>{});
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

TEST(DataTest, TensorLambdaWorksForAnyTargetType) {
  auto d = TensorStringDataset().map(transforms::TensorLambda<std::string>(
      [](torch::Tensor input) { return input * 2; }));
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

struct DummyTensorDataset
    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
  Example<torch::Tensor, int> get(size_t index) override {
    const auto channels = static_cast<int64_t>(index);
    torch::Tensor tensor =
        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
    return {tensor, static_cast<int>(channels)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, NormalizeTransform) {
  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));

  // Works for zero (one implicit) channels
  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
  ASSERT_EQ(output.size(), 1);
  // (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
      << output[0].data;

  // Works for one explicit channel
  output = dataset.get_batch(1);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 1);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;

  // Works for two channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
  output = dataset.get_batch(2);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 2);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
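  // (1 - 1.5) / 0.2 = -2.5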
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with one moment value
  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;
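  // (1 - (-1.5)) / 0.2 = 12.5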
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/2)
                  .allclose(torch::ones({1, 4, 4}) * 12.5))
      << output[0].data;
}

struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
  UnCopyableDataset() = default;

  UnCopyableDataset(const UnCopyableDataset&) = delete;
  UnCopyableDataset& operator=(const UnCopyableDataset&) = delete;

  UnCopyableDataset(UnCopyableDataset&&) = default;
  UnCopyableDataset& operator=(UnCopyableDataset&&) = default;

  // NOLINTNEXTLINE(modernize-use-override)
  ~UnCopyableDataset() = default;

  Example<> get(size_t index) override {
    return {
        torch::tensor({static_cast<int64_t>(index)}),
        torch::tensor({static_cast<int64_t>(index)})};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, MapDoesNotCopy) {
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  ASSERT_EQ(data[0].item<float>(), 7);
}

TEST(DataTest, QueuePushAndPopFromSameThread) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  ASSERT_EQ(queue.pop(), 1);
  ASSERT_EQ(queue.pop(), 2);
}

TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  torch::data::detail::Queue<int> queue;
  ASSERT_THROWS_WITH(
      queue.pop(10 * kMillisecond),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}

TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push a value, then pop it from another thread.
  {
    Queue<int> queue;
    queue.push(1);
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: attempt to pop a value (and block), then push.
  {
    Queue<int> queue;
    std::thread thread([&queue] {
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    ASSERT_EQ(queue.pop(), 123);
    thread.join();
  }
}

TEST(DataTest, QueueClearEmptiesTheQueue) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  queue.push(3);
  ASSERT_EQ(queue.clear(), 3);
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}

TEST(DataTest, DataShuttleCanPushAndPopJob) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_job(2);
  ASSERT_EQ(shuttle.pop_job(), 1);
  ASSERT_EQ(shuttle.pop_job(), 2);
}

TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() will only attempt to pop if there was a push_job() first.
  shuttle.push_job(1);
  shuttle.push_job(2);

  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  shuttle.pop_job();
  shuttle.push_result(2);
  ASSERT_EQ(shuttle.pop_result().value(), 2);
}

TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  ASSERT_FALSE(shuttle.pop_result().has_value());
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);
  ASSERT_FALSE(shuttle.pop_result().has_value());
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttlePopResultTimesOut) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  ASSERT_THROWS_WITH(shuttle.pop_result(10 * kMillisecond), "Timeout");
}

struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
  UncopyableDataset(const std::string& /* unused */) {}

  UncopyableDataset(UncopyableDataset&&) = default;
  UncopyableDataset& operator=(UncopyableDataset&&) = default;

  UncopyableDataset(const UncopyableDataset&) = delete;
  UncopyableDataset& operator=(const UncopyableDataset&) = delete;

  int get(size_t index) override {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, SharedBatchDatasetReallyIsShared) {
  // This test will only compile if really no copies are made: the
  // UncopyableDataset is move-only. There is otherwise no logic to test, and
  // because it is not deterministic how many times and when the worker
  // threads access the shared dataset, we don't make any additional
  // assertions here.

  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          "uncopyable");

  auto data_loader = torch::data::make_data_loader(
      shared_dataset, torch::data::DataLoaderOptions().workers(3));

  for (auto batch : *data_loader) {
    /* exhaust */
  }
}

TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // This will not compile if a copy is made.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  ASSERT_EQ(shared_dataset.size().value(), 100);
}

struct TestIndex : public torch::data::samplers::CustomBatchRequest {
  explicit TestIndex(size_t offset, std::vector<size_t> index)
      : offset(offset), index(std::move(index)) {}
  size_t size() const override {
    return index.size();
  }
  size_t offset;
  std::vector<size_t> index;
};

struct TestIndexDataset
    : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> {
  explicit TestIndexDataset(size_t size) : data(size) {
    std::iota(data.begin(), data.end(), size_t(0));
  }
  std::vector<int> get_batch(TestIndex index) override {
    std::vector<int> batch;
    for (auto i : index.index) {
      batch.push_back(index.offset + data.at(i));
    }
    return batch;
  }
  torch::optional<size_t> size() const override {
    return data.size();
  }
  std::vector<int> data;
};

struct TestIndexSampler : public samplers::Sampler<TestIndex> {
  explicit TestIndexSampler(size_t size) : size_(size) {}
  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
  torch::optional<TestIndex> next(size_t batch_size) override {
    if (index_ >= size_) {
      return torch::nullopt;
    }
    std::vector<size_t> indices(batch_size);
    std::iota(indices.begin(), indices.end(), size_t(0));
    index_ += batch_size;
    return TestIndex(batch_size, std::move(indices));
  }
  void save(torch::serialize::OutputArchive& archive) const override {}
  void load(torch::serialize::InputArchive& archive) override {}
  size_t index_ = 0;
  size_t size_;
};

TEST(DataTest, CanUseCustomTypeAsIndexType) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader(
      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);

  for (auto batch : *data_loader) {
    for (const auto j : c10::irange(kBatchSize)) {
      ASSERT_EQ(batch.at(j), 10 + j);
    }
  }
}

TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  samplers::DistributedRandomSampler drs(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = drs.next(3)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  {
    samplers::DistributedRandomSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.set_epoch(3);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.epoch(), 3);
  }
}

TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t batch_size = 3;
  samplers::DistributedSequentialSampler dss(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = dss.next(batch_size)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (const auto i : c10::irange(res.size())) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        samplers;

    for (const auto i : c10::irange(num_replicas)) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (const auto i : c10::irange(num_replicas)) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  {
    samplers::DistributedSequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedSequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  DataLoaderOptions partial_options;
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 1);
  ASSERT_FALSE(full_options.drop_last);
  ASSERT_EQ(full_options.workers, 0);
  ASSERT_EQ(full_options.max_jobs, 0);
  ASSERT_FALSE(full_options.timeout.has_value());
  ASSERT_TRUE(full_options.enforce_ordering);
}

TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  auto partial_options = DataLoaderOptions(32).workers(10);
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 32);
  ASSERT_EQ(full_options.max_jobs, 2 * 10);
}

TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  auto data_loader = torch::data::make_data_loader(
      DummyDataset().map(transforms::Lambda<int>([](int x) { return x + 1; })));
  ASSERT_EQ(data_loader->options().batch_size, 1);
}

struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
  torch::data::Example<> get(size_t i) override {
    return {torch::ones(i), torch::ones(i)};
  }
  torch::optional<size_t> size() const noexcept override {
    return torch::nullopt;
  }
};

TEST(
    DataLoaderTest,
    MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  ASSERT_THROWS_WITH(
      torch::data::make_data_loader(UnsizedDataset{}),
      "Expected the dataset to be sized in order to construct the Sampler");
}

TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto begin = data_loader->begin();
  ASSERT_EQ(begin, begin);
  auto end = data_loader->end();
  ASSERT_EQ(end, end);
}

TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->begin();
  auto j = data_loader->begin();
  ASSERT_NE(i, j);
  ++j;
  ASSERT_NE(i, j);
}

TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->end();
  auto j = data_loader->end();
  ASSERT_EQ(i, j);
}

TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 4);
  auto i = data_loader->begin();
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_EQ(i, end);
}

TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto i = data_loader->begin();
  auto j = i;
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++i;
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++j;
  ASSERT_EQ(i, end);
  ASSERT_EQ(j, end);
}

TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset,
          // NOLINTNEXTLINE(bugprone-argument-comment)
          /*batch_size=*/1);
  auto iterator = data_loader->begin();
  std::vector<int> expected = {1};
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 2;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 3;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
}

TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  struct D : datasets::BatchDataset<D, int> {
    int get_batch(torch::ArrayRef<size_t> indices) override {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      return 1 + indices.front();
    }
    torch::optional<size_t> size() const override {
      return 10;
    }
  };

  D dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset, 1);
  std::vector<int> values;
  std::copy(
      data_loader->begin(), data_loader->end(), std::back_inserter(values));
  std::vector<int> expected(dataset.size().value());
  std::iota(expected.begin(), expected.end(), size_t(1));
  ASSERT_EQ(values, expected);
}

TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2));
  auto i = data_loader->begin();
  ASSERT_THROWS_WITH(
      data_loader->begin(),
      "Attempted to get a new DataLoader iterator "
      "while another iterator is not yet exhausted");
}

TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(++i, "Attempted to increment iterator past the end");
}

TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(
      *i, "Attempted to dereference iterator that was past the end");
}

TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      ++i,
      "Incrementing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      *i,
      "Dereferencing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(dataset, 25);
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(false));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 1);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(true));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(DataLoaderTest, RespectsTimeout) {
  struct Baton {
    std::condition_variable cv;
    std::mutex mutex;
  };

  struct D : datasets::Dataset<D, int> {
    D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
    int get(size_t index) override {
      std::unique_lock<std::mutex> lock(baton->mutex);
      baton->cv.wait_for(lock, 1000 * kMillisecond);
      return 0;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    std::shared_ptr<Baton> baton;
  };

  auto baton = std::make_shared<Baton>();

  auto data_loader = torch::data::make_data_loader(
      D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond));

  auto start = std::chrono::system_clock::now();

  ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout");
  baton->cv.notify_one();

  auto end = std::chrono::system_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  ASSERT_LT(duration.count(), 1);
}

// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (--counter_ == 0) {
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return this->counter_ == 0; });
    }
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};
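
// A quick sanity check of the Barrier above: wait() must release every
// participant once the last one arrives; if it did not, this test would hang
// on join() and time out.
TEST(DataLoaderTest, BarrierReleasesAllParticipants) {
  const size_t kParticipants = 4;
  Barrier barrier(kParticipants);
  std::mutex mutex;
  size_t released = 0;
  std::vector<std::thread> threads;
  for (const auto i : c10::irange(kParticipants - 1)) {
    (void)i;
    threads.emplace_back([&] {
      barrier.wait();
      std::lock_guard<std::mutex> lock(mutex);
      ++released;
    });
  }
  // The main thread is the last participant and releases everyone.
  barrier.wait();
  for (auto& thread : threads) {
    thread.join();
  }
  std::lock_guard<std::mutex> lock(mutex);
  ASSERT_EQ(released, kParticipants - 1);
}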

// On the OrderingTest: This test is intended to verify that the
// `enforce_ordering` option of the dataloader works correctly. The reason this
// flag exists is that when the dataloader has multiple workers (threads)
// enabled and this flag is not set, the order in which worker threads finish
// loading their respective batch and push it back to the dataloader's main
// thread (for outside consumption) is not deterministic. Imagine the sampler
// is a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each
// index will be a single "job". Inside the dataloader, worker threads block
// until a job is available. It is not deterministic which worker thread wakes
// up first to dequeue a particular job. Further, some worker threads may take
// longer than others to read the data for their index. As such, it could be
// that worker thread 2 finishes before all other threads and returns its
// batch to the main thread. In that case, the dataloader iterator would
// return the datum at index 2 first, and afterwards the datum from whatever
// thread finishes next. As such, the user may see data from indices 2, 0, 3,
// 1. On another run of the same dataloader on the same data, threads may be
// scheduled differently and return in order 0, 2, 3, 1. To force this
// ordering to deterministically be 0, 1, 2, 3, the `enforce_ordering` flag
// can be set to true. In that case, the dataloader will use a *sequencer*
// internally which keeps track of which datum is expected next, and buffers
// any other results until that next expected value arrives. For example,
// workers 1, 2, 3 may finish before worker 0. If `enforce_ordering` is true,
// the sequencer will internally buffer the results from 1, 2, 3 until worker
// 0 finishes. Only then does the dataloader return the datum from worker 0 to
// the user (and then datum 1 the next time, then 2 and so on).
//
// The way the test works is that we start `kNumberOfWorkers` workers in the
// dataloader, which each get an index from a `SequentialSampler` in the range
// `0...kNumberOfWorkers-1`. Each worker thread has a copy of the dataset, and
// thus `get_batch()` is called on the thread-local copy in each worker. We
// want to simulate out-of-order completion of these threads. For this, we
// first set a barrier in the `get_batch()` method to make sure every worker
// has some index to fetch assigned. Further, each worker thread has a unique
// ID in `0...kNumberOfWorkers-1`. There is a hard-coded ordering,
// `kOrderInWhichWorkersReturnTheirBatch`, in which we want the worker threads
// to return. For this, an iterator into this order is maintained. When the
// dereferenced iterator (the current order index) matches the thread ID of a
// worker, it knows it can now return its index as well as progress the
// iterator. Inside the dataloader, the sequencer should buffer these indices
// such that they are ultimately returned in order.

namespace ordering_test {
namespace {
const size_t kNumberOfWorkers = 10;
const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
    {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
} // namespace

struct Dataset : datasets::BatchDataset<Dataset, size_t> {
  Dataset() = default;

  // This copy constructor will be called when we copy the dataset into a
  // particular thread.
  Dataset(const Dataset& other) {
    static std::atomic<size_t> counter{0};
    thread_id_ = counter.fetch_add(1);
  }

  Dataset(Dataset&& other) noexcept = default;
  Dataset& operator=(const Dataset& other) = delete;
  Dataset& operator=(Dataset&& other) noexcept = delete;

  size_t get_batch(torch::ArrayRef<size_t> indices) override {
    static Barrier barrier(kNumberOfWorkers);
    static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
    static std::condition_variable cv;
    static std::mutex mutex;

    // Wait for all threads to get an index to fetch and arrive here.
    barrier.wait();

    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
    ++order_iterator;
    lock.unlock();
    cv.notify_all();

    return indices.front();
  }

  torch::optional<size_t> size() const override {
    return kNumberOfWorkers;
  }

  size_t thread_id_ = 0;
};

} // namespace ordering_test

TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
  auto data_loader = torch::data::make_data_loader(
      ordering_test::Dataset{},
      torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
      DataLoaderOptions()
          .batch_size(1)
          .workers(ordering_test::kNumberOfWorkers)
          .enforce_ordering(true));
  std::vector<size_t> output;
  for (size_t value : *data_loader) {
    output.push_back(value);
  }
  std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
  std::iota(expected.begin(), expected.end(), size_t(0));
  ASSERT_EQ(expected, output);
}

TEST(DataLoaderTest, Reset) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto end = data_loader->end();

  auto iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);
}

TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  struct D : datasets::Dataset<D, int> {
    int get(size_t index) override {
      throw std::invalid_argument("badness");
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
  };

  auto data_loader = torch::data::make_data_loader(
      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
  auto iterator = data_loader->begin();

  try {
    (void)*iterator;
  } catch (torch::data::WorkerException& e) {
    ASSERT_EQ(
        e.what(),
        std::string("Caught exception in DataLoader worker thread. "
                    "Original message: badness"));
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_THROW(
        std::rethrow_exception(e.original_exception), std::invalid_argument);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> lock(mutex);
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
    std::mutex mutex;
  };

  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(
      D().map(transforms::BatchLambda<int, std::string>(
                  [](int x) { return std::to_string(x); }))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              [](const std::string& x) {
                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
              })),
      DataLoaderOptions{});

  for (const auto i : c10::irange(10)) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const torch::Tensor& t : *data_loader) {
    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        counter += batch_size;
        std::vector<Example<>> batch(
            /*count=*/batch_size,
            Example<>{
                torch::ones(batch_size + 1), torch::zeros(batch_size - 1)});
        return batch;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}
    int counter = 0;
  };

  auto d = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // Note that the dataset's `get_batch()` returns a vector<Example>, but the
  // `Stack` collation stacks the tensors into a single Example.
  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
  ASSERT_TRUE(batch.has_value());
  ASSERT_EQ(batch->data.size(0), kBatchSize);
  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  ASSERT_EQ(batch->target.size(0), kBatchSize);
  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);

  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
}

// This test exercises the core logic for iterating through a chunk dataset.
// It covers different parameter combinations (for example, different prefetch
// counts, batch sizes, and data loader worker counts), and verifies the
// returned batch sizes and contents when the order is deterministic.
TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  // different prefetch counts for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t prefetch_counts[] = {1, 2, 3, 4};

  // different batch sizes for testing.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {5, 7};

  // test with/without worker threads
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t dataloader_worker_counts[] = {0, 2};

  const size_t total_example_count = 35;
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // test functionality across epoch boundary
  const int epoch_count = 2;

  for (auto prefetch_count : prefetch_counts) {
    for (auto batch_size : batch_sizes) {
      for (auto dataloader_worker_count : dataloader_worker_counts) {
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));

        auto data_loader = torch::data::make_data_loader(
            dataset,
            DataLoaderOptions(batch_size).workers(dataloader_worker_count));

        for (const auto epoch_index : c10::irange(epoch_count)) {
          (void)epoch_index; // Suppress unused variable warning
          std::vector<bool> result(total_example_count, false);
          int iteration_count = 0;
          for (auto iterator = data_loader->begin();
               iterator != data_loader->end();
               ++iterator, ++iteration_count) {
            DummyChunkDataReader::BatchType& batch = *iterator;
            ASSERT_EQ(batch.size(), batch_size);

            // When prefetch_count is 1 and there are no worker threads, the
            // batch order is deterministic, so we can verify the elements of
            // each batch.
            if (prefetch_count == 1 && dataloader_worker_count == 0) {
              for (const auto j : c10::irange(batch_size)) {
                ASSERT_EQ(batch[j], iteration_count * batch_size + j);
              }
            }
            for (const auto j : c10::irange(batch_size)) {
              result[batch[j]] = true;
            }
          }

          for (auto data : result) {
            ASSERT_EQ(data, true);
          }
        }
      }
    }
  }
}

TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  const size_t requested_batch_size = 6;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(requested_batch_size).workers(0));

  std::string exception_msg =
      "The requested batch size does not match with the initialized batch "
      "size.\n The requested batch size is 6, while the dataset is created"
      " with batch size equal to 5";

  ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
}

TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
  struct DummyEmptyChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return {};
    }

    size_t chunk_count() override {
      return 1;
    }

    void reset() override {}
  };

  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  DummyEmptyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyEmptyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyEmptyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(0));

  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    ASSERT_EQ(iterator->size(), 0);
  }
}

TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(10, 0);
      return batch_data;
    }

    size_t chunk_count() override {
      return 2;
    }

    void reset() override {}
  };

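  // Two chunks of 10 examples give 20 examples in total: a requested batch
  // size of 17 yields batches of 17 and then 3, while a request of 30 is
  // capped at the 20 available examples.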
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t batch_sizes[] = {17, 30};
  D data_reader;
  samplers::SequentialSampler sampler(0);

  for (auto batch_size : batch_sizes) {
    datasets::SharedBatchDataset<datasets::ChunkDataset<
        D,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            D,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            datasets::ChunkDatasetOptions(1, batch_size));

    auto data_loader = torch::data::make_data_loader(
        dataset, DataLoaderOptions(batch_size).workers(0));

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator) {
      D::BatchType batch = *iterator;
      // Don't shadow the requested `batch_size`; compare the returned batch
      // size against it instead.
      auto returned_batch_size = batch.size();
      if (batch_size == 17) {
        ASSERT_TRUE(returned_batch_size == 17 || returned_batch_size == 3);
      }
      if (batch_size == 30) {
        ASSERT_TRUE(returned_batch_size == 20);
      }
    }
  }
}

TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(
          [](DummyChunkDataReader::BatchType batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(batch_size).workers(0));

  // Before we start, the chunk sampler's index should be 0.
  ASSERT_EQ(chunk_sampler.index(), 0);

  size_t sum = 0;
  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    sum += *iterator;
  }
  ASSERT_EQ(sum, 595); // sum of [0, 35)
  // There are 3 chunks; once the sampler is exhausted, its index has already
  // advanced past the last chunk.
  ASSERT_EQ(chunk_sampler.index(), 3);
}

TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
  const size_t prefetch_count = 2;
  const size_t batch_size = 5;
  // A cache of this size forces the preloaders to block until `get_batch()`
  // drains the cache.
  const size_t cache_size = 10;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);
  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, cache_size));

  auto data_loader = torch::data::make_data_loader(
      dataset.map(transforms::BatchLambda<
                  DummyChunkDataReader::BatchType,
                  DummyChunkDataReader::DataType>(
          [](DummyChunkDataReader::BatchType batch) {
            return std::accumulate(batch.begin(), batch.end(), 0);
          })),
      DataLoaderOptions(batch_size).workers(0));
  // Simply create the iterator without iterating. The chunk preloaders are
  // waiting to fill the batch buffer, but nothing drains it. We still need
  // to exit cleanly.
  auto iterator = data_loader->begin();
}

// Test the ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset is used by a thread pool separate
// from the main thread, so it is very hard to accurately estimate its state
// when ChunkDataset::save/ChunkDataset::load is called. Purely for testing,
// we exploit the implementation fact that the file format for sampler
// serialization is the same as for ChunkDataset serialization, and manually
// drive the chunk sampler through its save/load methods for value
// validation. This is only for testing the save/load functionality; real
// users should use the matching ChunkDataset::save and ChunkDataset::load
// methods.
TEST(DataLoaderTest, ChunkDatasetSave) {
  const size_t chunk_count_ = 6;
  const size_t chunk_size = 10;

  struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;

    BatchType read_chunk(size_t chunk_index) override {
      return batch_data_;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    BatchType batch_data_ = BatchType(chunk_size, 0);
  };

  const size_t prefetch_count = 1;
  const size_t batch_size = chunk_size;
  const size_t dataloader_worker_count = 0;
  samplers::SequentialSampler sampler(0);
  const int epoch_count = 2;

  DummyTestChunkDataReader data_reader;

  // tested save_intervals
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t save_intervals[] = {1, 2};

  using datasets::ChunkDatasetOptions;

  for (auto save_interval : save_intervals) {
    auto tempfile = c10::make_tempfile();

    datasets::SharedBatchDataset<datasets::ChunkDataset<
        DummyTestChunkDataReader,
        samplers::SequentialSampler,
        samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyTestChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
            sampler,
            ChunkDatasetOptions(
                prefetch_count, batch_size, chunk_size /*cache size*/));

    auto data_loader = torch::data::make_data_loader(
        dataset,
        DataLoaderOptions(batch_size).workers(dataloader_worker_count));

    for (const auto epoch_index : c10::irange(epoch_count)) {
      (void)epoch_index; // Suppress unused variable warning
      unsigned iteration_count = 0;
      for (auto iterator = data_loader->begin();
           iterator != data_loader->end();
           ++iterator, ++iteration_count) {
        if ((iteration_count + 1) % save_interval == 0) {
          torch::save(*dataset, tempfile.name);

          samplers::SequentialSampler new_sampler(0);

          // See Note [save/load ChunkDataset as ChunkSampler]
          torch::load(new_sampler, tempfile.name);

          // Verify the save logic. For ChunkDataset, the chunk data is stored
          // in a cache inside the dataset. One pool of threads is constantly
          // writing to the cache while a different pool is constantly reading
          // from it. Because of this asynchrony, it is not fully
          // deterministic which chunk has been written to the cache at the
          // time get_batch() is called.
          // But we can still derive a restricted window for the expected
          // output and verify the logic against it. In this test, the cache
          // size is configured to be the same as the chunk size and batch
          // size, so the chunks are written to the cache one by one: only
          // after the current batch is retrieved is the next chunk written.
          // Now in iteration 0, after the first batch is retrieved, when we
          // save the dataset's state there are three possible scenarios for
          // the writer thread:
          // 1. it hasn't started loading the next chunk yet, so the
          // sequential sampler index is still 0;
          // 2. it has started to load the second chunk, so the sequential
          // sampler index is at 1;
          // 3. it has finished loading the second chunk and started to load
          // the third, but because the cache is still fully occupied by the
          // second chunk's data, it is waiting to write. At this point the
          // sampler index is at 2.
          // So we have a window of [0, 2] within which we expect the sampler
          // to have saved its index. Note that the sequential sampler
          // advances to the next index automatically inside next(), so it
          // saves the next index instead of the current one. In other words,
          // after handing out the first index, the sampler has already moved
          // to the second one, and that is what gets saved. As a result the
          // window shifts by one, giving the expected window [1, 3].
          // The same analysis applies to every iteration, so in general the
          // saved index should fall into the range
          // [iteration_count + 1, iteration_count + 3], which is what the
          // validation below checks.
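          // Concretely (illustrative numbers): at iteration_count == 0 the
          // writer's sampler index is 0, 1, or 2 depending on how far it has
          // gotten, and because next() pre-advances, the saved value lands
          // in [1, 3]; at iteration_count == 1 the same reasoning gives
          // [2, 4], and so on.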
          ASSERT_TRUE(
              new_sampler.index() >= iteration_count + 1 &&
              new_sampler.index() <= iteration_count + 3);
        }
      }
    }
  }
}

// Test ChunkDataset load function.
TEST(DataLoaderTest, ChunkDatasetLoad) {
  auto tempfile = c10::make_tempfile();

  const size_t prefetch_count = 1;
  const size_t batch_size = 10;
  const size_t dataloader_worker_count = 0;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  const size_t skipped_chunk = 2;

  // Configure sampler to skip 2 chunks
  {
    sampler.reset(data_reader.chunk_count());
    sampler.next(skipped_chunk);

    // See Note [save/load ChunkDataset as ChunkSampler]
    torch::save(sampler, tempfile.name);
  }

  // Test functionality across epoch boundaries. The first epoch should be
  // affected by the checkpoint, but the second should start normally.
  const int epoch_count = 2;

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(
              prefetch_count, batch_size, 20 /*cache size*/));

  torch::load(*dataset, tempfile.name);

  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));

  for (const auto epoch_index : c10::irange(epoch_count)) {
    int iteration_count = 0;

    // For the first epoch, batches should come from the third chunk, because
    // the checkpoint skipped the first two chunks. For the next epoch,
    // iteration should start from the first chunk. Chunks 0 and 1 hold
    // 10 + 5 = 15 examples in total, so the third chunk starts at value 15.
    int initial_value = epoch_index == 0 ? 15 : 0;

    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
         ++iterator, ++iteration_count) {
      DummyChunkDataReader::BatchType batch = *iterator;

      std::vector<int> expected_result;
      size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
      expected_result.resize(expected_size);
      std::iota(expected_result.begin(), expected_result.end(), initial_value);

      ASSERT_EQ(batch.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(batch.begin(), batch.end(), expected_result.begin()));

      initial_value += batch_size;
    }
  }

  samplers::SequentialSampler new_sampler(0);

  // See Note [save/load ChunkDataset as ChunkSampler]
  torch::load(new_sampler, tempfile.name);

  ASSERT_EQ(new_sampler.index(), skipped_chunk);
}
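
// A minimal round-trip sketch distilling Note [save/load ChunkDataset as
// ChunkSampler] above. The test name and the advance count of 4 are
// illustrative additions, not part of the original suite; only APIs already
// exercised in this file are used.
TEST(DataLoaderTest, SamplerSaveLoadRoundTripSketch) {
  samplers::SequentialSampler saved(0);
  saved.reset(/*new_size=*/10);
  saved.next(4); // hand out indices 0..3; the sampler now sits at index 4

  auto tempfile = c10::make_tempfile();
  torch::save(saved, tempfile.name);

  samplers::SequentialSampler restored(0);
  torch::load(restored, tempfile.name);
  ASSERT_EQ(restored.index(), 4);
}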

TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
  const size_t chunk_size = 5;
  const size_t batch_size = 5;

  class S : public samplers::Sampler<> {
   public:
    explicit S(size_t size) : size_(size), index_(0) {}

    void reset(torch::optional<size_t> new_size = torch::nullopt) override {
      if (new_size.has_value()) {
        size_ = *new_size;
      }
      indices_.resize(size_);
      size_t index = 0;

      // Repeatedly sample every `batch_size`-th index; for size_ == 10 this
      // produces {0, 5, 1, 6, 2, 7, 3, 8, 4, 9}.
      for (const auto i : c10::irange(batch_size)) {
        for (size_t j = 0; j < size_ / batch_size; ++j) {
          indices_[index++] = i + batch_size * j;
        }
      }
      index_ = 0;
    }

    // Returns the next batch of indices.
    torch::optional<std::vector<size_t>> next(size_t batch_size) override {
      const auto remaining_indices = size_ - index_;
      if (remaining_indices == 0) {
        return torch::nullopt;
      }
      auto return_size = std::min(batch_size, remaining_indices);
      std::vector<size_t> index_batch(
          indices_.begin() + index_, indices_.begin() + index_ + return_size);
      index_ += return_size;

      return index_batch;
    }

    void save(torch::serialize::OutputArchive& archive) const override {}
    void load(torch::serialize::InputArchive& archive) override {}

   private:
    size_t size_;
    std::vector<size_t> indices_;
    size_t index_{0};
  };

  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size, chunk_index);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    size_t chunk_count_;
  };

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {2, 3};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4, 5};

  samplers::SequentialSampler chunk_sampler(0);
  S example_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<
          datasets::ChunkDataset<D, samplers::SequentialSampler, S>>
          dataset = datasets::make_shared_dataset<
              datasets::ChunkDataset<D, samplers::SequentialSampler, S>>(
              data_reader,
              chunk_sampler,
              example_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count));

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      std::vector<int> result;
      for (auto iterator = data_loader->begin();
           iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        std::copy(
            batch_result.begin(),
            batch_result.end(),
            std::back_inserter(result));
      }

      std::vector<int> expected_result;
      {
        // Construct the expected result: chunks are loaded in groups of
        // cross_chunk_shuffle_count, and within each group the example
        // sampler interleaves the chunks' values. For chunk_count == 3 and
        // cross_chunk_shuffle_count == 2, this yields
        // {0,1, 0,1, 0,1, 0,1, 0,1, 2,2,2,2,2}.
        for (const auto i : c10::irange(
                 (chunk_count + cross_chunk_shuffle_count - 1) /
                 cross_chunk_shuffle_count)) {
          for (const auto j : c10::irange(chunk_size)) {
            (void)j; // Suppress unused variable warning
            for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
              if (i * cross_chunk_shuffle_count + k < chunk_count) {
                expected_result.push_back(i * cross_chunk_shuffle_count + k);
              }
            }
          }
        }
      }

      ASSERT_EQ(result.size(), expected_result.size());
      ASSERT_TRUE(
          std::equal(result.begin(), result.end(), expected_result.begin()));
    }
  }
}

TEST(DataLoaderTest, CustomPreprocessPolicy) {
  const size_t chunk_size = 5;
  const size_t batch_size = 10;

  struct D : public datasets::ChunkDataReader<int> {
   public:
    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
    D(size_t chunk_count) : chunk_count_(chunk_count) {}

    BatchType read_chunk(size_t chunk_index) override {
      BatchType batch_data(chunk_size);
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand)
      auto rand_gen = []() { return std::rand() % 100; };
      std::generate(batch_data.begin(), batch_data.end(), rand_gen);
      return batch_data;
    }

    size_t chunk_count() override {
      return chunk_count_;
    }

    void reset() override {}
    size_t chunk_count_;
  };

  // Custom preprocessing policy: sort the data in ascending order.
  auto sorting_policy = [](std::vector<int>& raw_batch_data) {
    std::sort(raw_batch_data.begin(), raw_batch_data.end());
  };
  std::function<void(std::vector<int>&)> policy_function = sorting_policy;

  const size_t prefetch_count = 1;
  const size_t cache_size = 10;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t cross_chunk_shuffle_counts[] = {1, 2};
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  const size_t chunk_counts[] = {3, 4};

  samplers::SequentialSampler chunk_sampler(0);

  for (auto chunk_count : chunk_counts) {
    for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
      D data_reader(chunk_count);

      datasets::SharedBatchDataset<datasets::ChunkDataset<
          D,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>
          dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
              D,
              samplers::SequentialSampler,
              samplers::SequentialSampler>>(
              data_reader,
              chunk_sampler,
              chunk_sampler,
              datasets::ChunkDatasetOptions(
                  prefetch_count,
                  batch_size,
                  cache_size,
                  cross_chunk_shuffle_count),
              policy_function);

      auto data_loader = torch::data::make_data_loader(
          dataset, DataLoaderOptions(batch_size).workers(0));

      std::vector<int> result;
      for (auto iterator = data_loader->begin();
           iterator != data_loader->end();
           ++iterator) {
        auto batch_result = *iterator;
        if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
          for (unsigned i = 0; i < batch_result.size(); i += chunk_size) {
            ASSERT_TRUE(std::is_sorted(
                batch_result.begin() + i,
                batch_result.begin() + i + chunk_size));
          }
        } else {
          ASSERT_TRUE(
              std::is_sorted(batch_result.begin(), batch_result.end()));
        }
      }
    }
  }
}
