1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/kernels/record_yielder.h"

#include <utility>

#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
21 | |
22 | namespace tensorflow { |
23 | |
24 | RecordYielder::RecordYielder(OpKernelConstruction* context, |
25 | const RecordYielder::Options& opts) |
26 | : opts_(opts), |
27 | thread_(new thread::ThreadPool(context->env(), ThreadOptions(), |
28 | "record_yielder" , 1 + opts.parallelism, |
29 | /* low_latency_hint */ false)), |
30 | epoch_(0), |
31 | rnd_(opts.seed) { |
32 | thread_->Schedule([this]() { MainLoop(); }); |
33 | } |
34 | |
35 | RecordYielder::~RecordYielder() { |
36 | { |
37 | mutex_lock l(mu_); |
38 | stop_ = true; |
39 | buf_empty_.notify_all(); |
40 | buf_enough_.notify_all(); |
41 | buf_not_full_.notify_all(); |
42 | } |
43 | main_loop_done_.WaitForNotification(); |
44 | delete thread_; |
45 | } |
46 | |
47 | Status RecordYielder::YieldOne(tstring* value) { |
48 | mutex_lock l(mu_); |
49 | while (!BufEnough() && status_.ok()) { |
50 | buf_enough_.wait(l); |
51 | } |
52 | if (status_.ok()) { |
53 | bool notify_no_longer_full = !BufNotFull(); |
54 | CHECK(!stop_ && !buf_.empty()); |
55 | *value = std::move(buf_.back()); |
56 | buf_.pop_back(); |
57 | ++num_records_yielded_in_epoch_; |
58 | // Assumption is that an epoch always has something in the buffer |
59 | // until it ends. If the input pipeline was slower than the consumers |
60 | // by a lot this might not be true. Not sure how to handle. |
61 | if (buf_.empty()) { |
62 | buf_empty_.notify_all(); |
63 | } |
64 | if (notify_no_longer_full) { |
65 | buf_not_full_.notify_all(); |
66 | } |
67 | } |
68 | return status_; |
69 | } |
70 | |
71 | struct RecordYielder::Shard { |
72 | int index; // Shard index. |
73 | std::vector<tstring> filenames; // File names given to this shard. |
74 | Notification done; // Notified when this shard is done. |
75 | Status status; // Shard status. |
76 | }; |
77 | |
78 | bool RecordYielder::ShouldFinish(const Status& s) { |
79 | mutex_lock l(mu_); |
80 | status_.Update(s); |
81 | return stop_ || !status_.ok(); |
82 | } |
83 | |
84 | static Status MatchFiles(const string& patterns, |
85 | std::vector<string>* filenames) { |
86 | for (const auto& file_pattern : str_util::Split(patterns, ',')) { |
87 | std::vector<string> tmp_filenames; |
88 | TF_RETURN_IF_ERROR( |
89 | Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames)); |
90 | filenames->insert(filenames->end(), |
91 | std::make_move_iterator(tmp_filenames.begin()), |
92 | std::make_move_iterator(tmp_filenames.end())); |
93 | } |
94 | return OkStatus(); |
95 | } |
96 | |
97 | void RecordYielder::MainLoop() { |
98 | while (true) { |
99 | ++epoch_; |
100 | num_records_yielded_in_epoch_ = 0; |
101 | num_records_added_in_epoch_ = 0; |
102 | |
103 | // Finds all files. |
104 | std::vector<string> filenames; |
105 | Status s = MatchFiles(opts_.file_pattern, &filenames); |
106 | |
107 | if (filenames.empty()) { |
108 | s = errors::NotFound("Found no files at " , opts_.file_pattern); |
109 | if (ShouldFinish(s)) { |
110 | buf_enough_.notify_all(); |
111 | break; |
112 | } |
113 | } |
114 | |
115 | if (ShouldFinish(s)) break; |
116 | |
117 | // Shuffles these files according to the epoch # and random seed. |
118 | std::mt19937_64 shuffle_rnd( |
119 | Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed)); |
120 | std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd); |
121 | |
122 | // Left-shift the filename list. |
123 | const std::vector<string>::size_type num = filenames.size(); |
124 | int64_t shift; |
125 | if (0 <= opts_.file_shuffle_shift_ratio && |
126 | opts_.file_shuffle_shift_ratio < 1) { |
127 | shift = opts_.file_shuffle_shift_ratio * num; |
128 | std::rotate(filenames.begin(), filenames.begin() + shift, |
129 | filenames.end()); |
130 | } |
131 | |
132 | // Shards files and use one thread to go through each shard. |
133 | const int N = opts_.parallelism; |
134 | std::vector<Shard> shards(N); |
135 | for (int i = 0; i < N; ++i) { |
136 | Shard* shard = &shards[i]; |
137 | shard->index = i; |
138 | for (std::vector<string>::size_type j = i; j < filenames.size(); j += N) { |
139 | shard->filenames.push_back(filenames[j]); |
140 | } |
141 | thread_->Schedule([this, shard]() { ShardLoop(shard); }); |
142 | } |
143 | for (int i = 0; i < N; ++i) { |
144 | shards[i].done.WaitForNotification(); |
145 | s.Update(shards[i].status); |
146 | } |
147 | |
148 | if (num_records_added_in_epoch_ < opts_.bufsize) { |
149 | mutex_lock l(mu_); |
150 | opts_.bufsize = num_records_added_in_epoch_; |
151 | } |
152 | |
153 | if (ShouldFinish(s)) { |
154 | buf_enough_.notify_all(); |
155 | break; |
156 | } |
157 | |
158 | // Starts the next epoch once all buffered records are consumed. |
159 | { |
160 | mutex_lock l(mu_); |
161 | epoch_end_ = true; |
162 | if (BufEnough()) { |
163 | buf_enough_.notify_all(); |
164 | } |
165 | while (!BufEmpty()) { |
166 | buf_empty_.wait(l); |
167 | } |
168 | epoch_end_ = false; |
169 | } |
170 | } |
171 | main_loop_done_.Notify(); |
172 | } |
173 | |
174 | bool RecordYielder::Add(std::vector<string>* values) { |
175 | mutex_lock l(mu_); |
176 | while (!BufNotFull()) { |
177 | buf_not_full_.wait(l); |
178 | } |
179 | while (BufNotFull() && !values->empty()) { |
180 | // Adds values->back(). Swaps its position with another random |
181 | // element. |
182 | auto index = rnd_() % (buf_.size() + 1); |
183 | if (index == buf_.size()) { |
184 | buf_.push_back(std::move(values->back())); |
185 | } else { |
186 | buf_.push_back(std::move(buf_[index])); |
187 | buf_[index] = std::move(values->back()); |
188 | } |
189 | values->pop_back(); |
190 | num_records_added_in_epoch_++; |
191 | } |
192 | if (BufEnough()) { |
193 | buf_enough_.notify_all(); |
194 | } |
195 | return stop_; |
196 | } |
197 | |
198 | void RecordYielder::ShardLoop(Shard* shard) { |
199 | std::vector<string> values; |
200 | const int64_t kRecords = 16; |
201 | for (const string& filename : shard->filenames) { |
202 | std::unique_ptr<RandomAccessFile> file; |
203 | if (ShouldFinish(OkStatus())) break; |
204 | Status s = Env::Default()->NewRandomAccessFile(filename, &file); |
205 | if (!s.ok()) { |
206 | shard->status = errors::InvalidArgument("Can't open " , filename); |
207 | break; |
208 | } |
209 | io::RecordReaderOptions options = |
210 | io::RecordReaderOptions::CreateRecordReaderOptions( |
211 | opts_.compression_type); |
212 | io::RecordReader rdr(file.get(), options); |
213 | uint64 offset = 0; |
214 | tstring record; |
215 | while (true) { |
216 | Status s = rdr.ReadRecord(&offset, &record); |
217 | if (s.ok()) { |
218 | values.emplace_back(std::move(record)); |
219 | if (values.size() >= kRecords && Add(&values)) { |
220 | shard->status = errors::Aborted("stopped" ); |
221 | break; |
222 | } |
223 | } else if (errors::IsOutOfRange(s)) { |
224 | break; |
225 | } else { |
226 | shard->status = s; |
227 | break; |
228 | } |
229 | } |
230 | } |
231 | // Adds the remaining values of this shard to buf_. |
232 | while (!values.empty()) { |
233 | Add(&values); |
234 | } |
235 | shard->done.Notify(); |
236 | } |
237 | |
238 | } // namespace tensorflow |
239 | |