1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/record_yielder.h"
17
18#include "tensorflow/core/lib/io/record_reader.h"
19#include "tensorflow/core/lib/strings/str_util.h"
20#include "tensorflow/core/platform/env.h"
21
22namespace tensorflow {
23
24RecordYielder::RecordYielder(OpKernelConstruction* context,
25 const RecordYielder::Options& opts)
26 : opts_(opts),
27 thread_(new thread::ThreadPool(context->env(), ThreadOptions(),
28 "record_yielder", 1 + opts.parallelism,
29 /* low_latency_hint */ false)),
30 epoch_(0),
31 rnd_(opts.seed) {
32 thread_->Schedule([this]() { MainLoop(); });
33}
34
35RecordYielder::~RecordYielder() {
36 {
37 mutex_lock l(mu_);
38 stop_ = true;
39 buf_empty_.notify_all();
40 buf_enough_.notify_all();
41 buf_not_full_.notify_all();
42 }
43 main_loop_done_.WaitForNotification();
44 delete thread_;
45}
46
47Status RecordYielder::YieldOne(tstring* value) {
48 mutex_lock l(mu_);
49 while (!BufEnough() && status_.ok()) {
50 buf_enough_.wait(l);
51 }
52 if (status_.ok()) {
53 bool notify_no_longer_full = !BufNotFull();
54 CHECK(!stop_ && !buf_.empty());
55 *value = std::move(buf_.back());
56 buf_.pop_back();
57 ++num_records_yielded_in_epoch_;
58 // Assumption is that an epoch always has something in the buffer
59 // until it ends. If the input pipeline was slower than the consumers
60 // by a lot this might not be true. Not sure how to handle.
61 if (buf_.empty()) {
62 buf_empty_.notify_all();
63 }
64 if (notify_no_longer_full) {
65 buf_not_full_.notify_all();
66 }
67 }
68 return status_;
69}
70
71struct RecordYielder::Shard {
72 int index; // Shard index.
73 std::vector<tstring> filenames; // File names given to this shard.
74 Notification done; // Notified when this shard is done.
75 Status status; // Shard status.
76};
77
78bool RecordYielder::ShouldFinish(const Status& s) {
79 mutex_lock l(mu_);
80 status_.Update(s);
81 return stop_ || !status_.ok();
82}
83
84static Status MatchFiles(const string& patterns,
85 std::vector<string>* filenames) {
86 for (const auto& file_pattern : str_util::Split(patterns, ',')) {
87 std::vector<string> tmp_filenames;
88 TF_RETURN_IF_ERROR(
89 Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames));
90 filenames->insert(filenames->end(),
91 std::make_move_iterator(tmp_filenames.begin()),
92 std::make_move_iterator(tmp_filenames.end()));
93 }
94 return OkStatus();
95}
96
97void RecordYielder::MainLoop() {
98 while (true) {
99 ++epoch_;
100 num_records_yielded_in_epoch_ = 0;
101 num_records_added_in_epoch_ = 0;
102
103 // Finds all files.
104 std::vector<string> filenames;
105 Status s = MatchFiles(opts_.file_pattern, &filenames);
106
107 if (filenames.empty()) {
108 s = errors::NotFound("Found no files at ", opts_.file_pattern);
109 if (ShouldFinish(s)) {
110 buf_enough_.notify_all();
111 break;
112 }
113 }
114
115 if (ShouldFinish(s)) break;
116
117 // Shuffles these files according to the epoch # and random seed.
118 std::mt19937_64 shuffle_rnd(
119 Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
120 std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);
121
122 // Left-shift the filename list.
123 const std::vector<string>::size_type num = filenames.size();
124 int64_t shift;
125 if (0 <= opts_.file_shuffle_shift_ratio &&
126 opts_.file_shuffle_shift_ratio < 1) {
127 shift = opts_.file_shuffle_shift_ratio * num;
128 std::rotate(filenames.begin(), filenames.begin() + shift,
129 filenames.end());
130 }
131
132 // Shards files and use one thread to go through each shard.
133 const int N = opts_.parallelism;
134 std::vector<Shard> shards(N);
135 for (int i = 0; i < N; ++i) {
136 Shard* shard = &shards[i];
137 shard->index = i;
138 for (std::vector<string>::size_type j = i; j < filenames.size(); j += N) {
139 shard->filenames.push_back(filenames[j]);
140 }
141 thread_->Schedule([this, shard]() { ShardLoop(shard); });
142 }
143 for (int i = 0; i < N; ++i) {
144 shards[i].done.WaitForNotification();
145 s.Update(shards[i].status);
146 }
147
148 if (num_records_added_in_epoch_ < opts_.bufsize) {
149 mutex_lock l(mu_);
150 opts_.bufsize = num_records_added_in_epoch_;
151 }
152
153 if (ShouldFinish(s)) {
154 buf_enough_.notify_all();
155 break;
156 }
157
158 // Starts the next epoch once all buffered records are consumed.
159 {
160 mutex_lock l(mu_);
161 epoch_end_ = true;
162 if (BufEnough()) {
163 buf_enough_.notify_all();
164 }
165 while (!BufEmpty()) {
166 buf_empty_.wait(l);
167 }
168 epoch_end_ = false;
169 }
170 }
171 main_loop_done_.Notify();
172}
173
174bool RecordYielder::Add(std::vector<string>* values) {
175 mutex_lock l(mu_);
176 while (!BufNotFull()) {
177 buf_not_full_.wait(l);
178 }
179 while (BufNotFull() && !values->empty()) {
180 // Adds values->back(). Swaps its position with another random
181 // element.
182 auto index = rnd_() % (buf_.size() + 1);
183 if (index == buf_.size()) {
184 buf_.push_back(std::move(values->back()));
185 } else {
186 buf_.push_back(std::move(buf_[index]));
187 buf_[index] = std::move(values->back());
188 }
189 values->pop_back();
190 num_records_added_in_epoch_++;
191 }
192 if (BufEnough()) {
193 buf_enough_.notify_all();
194 }
195 return stop_;
196}
197
198void RecordYielder::ShardLoop(Shard* shard) {
199 std::vector<string> values;
200 const int64_t kRecords = 16;
201 for (const string& filename : shard->filenames) {
202 std::unique_ptr<RandomAccessFile> file;
203 if (ShouldFinish(OkStatus())) break;
204 Status s = Env::Default()->NewRandomAccessFile(filename, &file);
205 if (!s.ok()) {
206 shard->status = errors::InvalidArgument("Can't open ", filename);
207 break;
208 }
209 io::RecordReaderOptions options =
210 io::RecordReaderOptions::CreateRecordReaderOptions(
211 opts_.compression_type);
212 io::RecordReader rdr(file.get(), options);
213 uint64 offset = 0;
214 tstring record;
215 while (true) {
216 Status s = rdr.ReadRecord(&offset, &record);
217 if (s.ok()) {
218 values.emplace_back(std::move(record));
219 if (values.size() >= kRecords && Add(&values)) {
220 shard->status = errors::Aborted("stopped");
221 break;
222 }
223 } else if (errors::IsOutOfRange(s)) {
224 break;
225 } else {
226 shard->status = s;
227 break;
228 }
229 }
230 }
231 // Adds the remaining values of this shard to buf_.
232 while (!values.empty()) {
233 Add(&values);
234 }
235 shard->done.Notify();
236}
237
238} // namespace tensorflow
239