1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
17#define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
18
#include <algorithm>
#include <atomic>
#include <random>
#include <string>
#include <vector>
23
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/lib/core/errors.h"
26#include "tensorflow/core/lib/core/notification.h"
27#include "tensorflow/core/lib/core/threadpool.h"
28#include "tensorflow/core/platform/macros.h"
29#include "tensorflow/core/platform/thread_annotations.h"
30
31namespace tensorflow {
32
33// RecordYielder produces value records from a set of tfrecord files
34// in a random order.
35//
36// It guarantees that:
37// 1) all records in tfrecords are yielded within every epoch;
38// 2) each record is yielded only once within every epoch;
39// 3) the order in which records are yielded is highly randomized.
40// 4) the peak memory usage is roughly avg record size *
41// (opts.bufsize + opts.parallelism * 16).
42//
43// Usage example:
44// RecordYielder::Options opts;
45// opts.file_pattern = "input-*";
46// opts.seed = 301;
47// opts.bufsize = 1000000; // A randomized buffer with 1M records.
48// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate
49// // through all files.
50// RecordYielder yielder(opts);
51// string val;
52// while (true) {
53// yielder.YieldOne(&val);
54// // process val
55// }
56//
57// RecordYielder can be accessed by multiple threads concurrently.
58class RecordYielder {
59 public:
60 struct Options {
61 // Glob pattern for tfrecords.
62 string file_pattern;
63
64 // Random seed. It determines how data files are shuffled and how
65 // records are shuffled.
66 int64_t seed = 0;
67
68 // Each epoch, all files are first shuffled according to the
69 // random seed and the epoch number, and then all files are
70 // left-shifted by file_shuffle_shift_ratio * num_files slots. If
71 // file_shuffle_shift_ratio is not within [0, 1), the
72 // implementation clip it to [0, 1).
73 float file_shuffle_shift_ratio = 0;
74
75 // Randomization buffer keeps these many records.
76 uint64 bufsize = 1;
77
78 // Uses these many concurrent tfrecord iterators to iterate through
79 // tfrecords.
80 int32 parallelism = 1;
81
82 string compression_type;
83 };
84
85 explicit RecordYielder(OpKernelConstruction* context,
86 const RecordYielder::Options& opts);
87 ~RecordYielder();
88
89 RecordYielder(const RecordYielder&) = delete;
90 RecordYielder& operator=(const RecordYielder&) = delete;
91
92 // Yields one 'value'.
93 Status YieldOne(tstring* value);
94
95 // Returns the current epoch number.
96 int64_t current_epoch() const { return epoch_; }
97
98 private:
99 typedef RecordYielder ME;
100
101 Options opts_;
102
103 // Backgrounds threads. Owned.
104 thread::ThreadPool* thread_;
105
106 // Epoch number.
107 std::atomic<int64_t> epoch_;
108
109 mutex mu_;
110
111 // Turned to true when this is deleted.
112 bool stop_ TF_GUARDED_BY(mu_) = false;
113 Status status_ TF_GUARDED_BY(mu_);
114
115 // PRG used for randomization.
116 std::mt19937_64 rnd_ TF_GUARDED_BY(mu_);
117
118 // Randomization buffer.
119 std::vector<string> buf_ TF_GUARDED_BY(mu_);
120
121 // True iff we are draining an epoch.
122 bool epoch_end_ = false;
123
124 int64_t num_records_added_in_epoch_ = 0;
125 int64_t num_records_yielded_in_epoch_ = 0;
126
127 // Trigger when the main loop has exited.
128 Notification main_loop_done_;
129
130 // condition_variables.
131 condition_variable buf_empty_;
132 bool BufEmpty() const TF_SHARED_LOCKS_REQUIRED(mu_) {
133 return stop_ || buf_.empty();
134 }
135
136 condition_variable buf_not_full_;
137 bool BufNotFull() const TF_SHARED_LOCKS_REQUIRED(mu_) {
138 return stop_ || buf_.size() < opts_.bufsize;
139 }
140
141 condition_variable buf_enough_;
142 bool BufEnough() const TF_SHARED_LOCKS_REQUIRED(mu_) {
143 // NOTE: Unless we are finishing an epoch, we want to make sure
144 // the buf_ contains enough randomized elements before yielding
145 // any.
146 return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
147 (!epoch_end_ &&
148 buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2));
149 }
150
151 void MainLoop();
152 struct Shard;
153 void ShardLoop(Shard* shard);
154 bool ShouldFinish(const Status& s);
155 bool Add(std::vector<string>* values);
156};
157
158} // namespace tensorflow
159
160#endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
161