1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ |
18 | |
#include <algorithm>
#include <atomic>
#include <random>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"
30 | |
31 | namespace tensorflow { |
32 | |
33 | // RecordYielder produces value records from a set of tfrecord files |
34 | // in a random order. |
35 | // |
36 | // It guarantees that: |
37 | // 1) all records in tfrecords are yielded within every epoch; |
38 | // 2) each record is yielded only once within every epoch; |
39 | // 3) the order in which records are yielded is highly randomized. |
40 | // 4) the peak memory usage is roughly avg record size * |
41 | // (opts.bufsize + opts.parallelism * 16). |
42 | // |
43 | // Usage example: |
44 | // RecordYielder::Options opts; |
45 | // opts.file_pattern = "input-*"; |
46 | // opts.seed = 301; |
47 | // opts.bufsize = 1000000; // A randomized buffer with 1M records. |
48 | // opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate |
49 | // // through all files. |
50 | // RecordYielder yielder(opts); |
51 | // string val; |
52 | // while (true) { |
53 | // yielder.YieldOne(&val); |
54 | // // process val |
55 | // } |
56 | // |
57 | // RecordYielder can be accessed by multiple threads concurrently. |
58 | class RecordYielder { |
59 | public: |
60 | struct Options { |
61 | // Glob pattern for tfrecords. |
62 | string file_pattern; |
63 | |
64 | // Random seed. It determines how data files are shuffled and how |
65 | // records are shuffled. |
66 | int64_t seed = 0; |
67 | |
68 | // Each epoch, all files are first shuffled according to the |
69 | // random seed and the epoch number, and then all files are |
70 | // left-shifted by file_shuffle_shift_ratio * num_files slots. If |
71 | // file_shuffle_shift_ratio is not within [0, 1), the |
72 | // implementation clip it to [0, 1). |
73 | float file_shuffle_shift_ratio = 0; |
74 | |
75 | // Randomization buffer keeps these many records. |
76 | uint64 bufsize = 1; |
77 | |
78 | // Uses these many concurrent tfrecord iterators to iterate through |
79 | // tfrecords. |
80 | int32 parallelism = 1; |
81 | |
82 | string compression_type; |
83 | }; |
84 | |
85 | explicit RecordYielder(OpKernelConstruction* context, |
86 | const RecordYielder::Options& opts); |
87 | ~RecordYielder(); |
88 | |
89 | RecordYielder(const RecordYielder&) = delete; |
90 | RecordYielder& operator=(const RecordYielder&) = delete; |
91 | |
92 | // Yields one 'value'. |
93 | Status YieldOne(tstring* value); |
94 | |
95 | // Returns the current epoch number. |
96 | int64_t current_epoch() const { return epoch_; } |
97 | |
98 | private: |
99 | typedef RecordYielder ME; |
100 | |
101 | Options opts_; |
102 | |
103 | // Backgrounds threads. Owned. |
104 | thread::ThreadPool* thread_; |
105 | |
106 | // Epoch number. |
107 | std::atomic<int64_t> epoch_; |
108 | |
109 | mutex mu_; |
110 | |
111 | // Turned to true when this is deleted. |
112 | bool stop_ TF_GUARDED_BY(mu_) = false; |
113 | Status status_ TF_GUARDED_BY(mu_); |
114 | |
115 | // PRG used for randomization. |
116 | std::mt19937_64 rnd_ TF_GUARDED_BY(mu_); |
117 | |
118 | // Randomization buffer. |
119 | std::vector<string> buf_ TF_GUARDED_BY(mu_); |
120 | |
121 | // True iff we are draining an epoch. |
122 | bool epoch_end_ = false; |
123 | |
124 | int64_t num_records_added_in_epoch_ = 0; |
125 | int64_t num_records_yielded_in_epoch_ = 0; |
126 | |
127 | // Trigger when the main loop has exited. |
128 | Notification main_loop_done_; |
129 | |
130 | // condition_variables. |
131 | condition_variable buf_empty_; |
132 | bool BufEmpty() const TF_SHARED_LOCKS_REQUIRED(mu_) { |
133 | return stop_ || buf_.empty(); |
134 | } |
135 | |
136 | condition_variable buf_not_full_; |
137 | bool BufNotFull() const TF_SHARED_LOCKS_REQUIRED(mu_) { |
138 | return stop_ || buf_.size() < opts_.bufsize; |
139 | } |
140 | |
141 | condition_variable buf_enough_; |
142 | bool BufEnough() const TF_SHARED_LOCKS_REQUIRED(mu_) { |
143 | // NOTE: Unless we are finishing an epoch, we want to make sure |
144 | // the buf_ contains enough randomized elements before yielding |
145 | // any. |
146 | return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) || |
147 | (!epoch_end_ && |
148 | buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2)); |
149 | } |
150 | |
151 | void MainLoop(); |
152 | struct Shard; |
153 | void ShardLoop(Shard* shard); |
154 | bool ShouldFinish(const Status& s); |
155 | bool Add(std::vector<string>* values); |
156 | }; |
157 | |
158 | } // namespace tensorflow |
159 | |
160 | #endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ |
161 | |