1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/reader_base.h"
17
18#include "tensorflow/core/framework/reader_base.pb.h"
19#include "tensorflow/core/framework/types.h"
20#include "tensorflow/core/lib/core/coding.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/core/notification.h"
23#include "tensorflow/core/lib/core/stringpiece.h"
24#include "tensorflow/core/lib/strings/str_util.h"
25#include "tensorflow/core/lib/strings/strcat.h"
26
27namespace tensorflow {
28
29// ReaderBase ------------------------------------------------------
30
31ReaderBase::ReaderBase(const string& name) : name_(name) {}
32
33int64_t ReaderBase::NumRecordsProduced() {
34 mutex_lock lock(mu_);
35 return num_records_produced_;
36}
37
38int64_t ReaderBase::NumWorkUnitsCompleted() {
39 mutex_lock lock(mu_);
40 return work_finished_;
41}
42
43Status ReaderBase::Reset() {
44 mutex_lock lock(mu_);
45 return ResetLocked();
46}
47
48Status ReaderBase::ResetLocked() {
49 work_started_ = 0;
50 work_finished_ = 0;
51 num_records_produced_ = 0;
52 work_.clear();
53 return OkStatus();
54}
55
56Status ReaderBase::SerializeState(tstring* state) {
57 mutex_lock lock(mu_);
58 return SerializeStateLocked(state);
59}
60
61Status ReaderBase::SerializeStateLocked(tstring* state) {
62 return errors::Unimplemented("Reader SerializeState");
63}
64
65Status ReaderBase::RestoreState(const tstring& state) {
66 mutex_lock lock(mu_);
67 Status status = RestoreStateLocked(state);
68 if (!status.ok()) {
69 ResetLocked().IgnoreError();
70 }
71 return status;
72}
73
74Status ReaderBase::RestoreStateLocked(const tstring& state) {
75 return errors::Unimplemented("Reader RestoreState");
76}
77
78int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue,
79 std::vector<tstring>* keys,
80 std::vector<tstring>* values,
81 OpKernelContext* context) {
82 mutex_lock lock(mu_);
83 int64_t records_produced_this_call = 0;
84 while (true) {
85 // Records produced by this iteration of the ReadUpToLocked call.
86 int64_t num_records_produced = 0;
87 int64_t remaining = num_records - records_produced_this_call;
88 if (remaining == 0) {
89 return records_produced_this_call;
90 }
91 if (!work_in_progress()) {
92 work_ = GetNextWorkLocked(queue, context);
93 if (!context->status().ok()) {
94 return records_produced_this_call;
95 }
96 Status status = OnWorkStartedLocked();
97 if (status.ok()) {
98 work_started_++;
99 } else {
100 context->SetStatus(status);
101 return records_produced_this_call;
102 }
103 }
104 bool at_end = false;
105
106 Status status =
107 ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
108 // This call so far.
109 records_produced_this_call += num_records_produced;
110
111 // In total, over the lifetime of the ReaderBase.
112 num_records_produced_ += num_records_produced;
113
114 if (!at_end && status.ok() && num_records_produced == 0) {
115 status = errors::Internal(
116 "ReadManyLocked() for ", name(),
117 " must set *at_end=true, *num_produced > 0 or return an error.");
118 context->SetStatus(status);
119 return records_produced_this_call;
120 }
121 if (status.ok() && at_end) {
122 status = OnWorkFinishedLocked();
123 work_finished_ = work_started_;
124 if (records_produced_this_call > 0) {
125 return records_produced_this_call;
126 }
127 }
128 if (!status.ok()) {
129 context->SetStatus(status);
130 return records_produced_this_call;
131 }
132 }
133}
134
135// Default implementation just reads one record at a time.
136Status ReaderBase::ReadUpToLocked(int64_t num_records,
137 std::vector<tstring>* keys,
138 std::vector<tstring>* values,
139 int64_t* num_read, bool* at_end) {
140 bool produced = false;
141 tstring key;
142 tstring value;
143 Status status = ReadLocked(&key, &value, &produced, at_end);
144 if (produced) {
145 keys->push_back(std::move(key));
146 values->push_back(std::move(value));
147 *num_read = 1;
148 } else {
149 *num_read = 0;
150 }
151 return status;
152}
153
154void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value,
155 OpKernelContext* context) {
156 mutex_lock lock(mu_);
157 while (true) {
158 if (!work_in_progress()) {
159 work_ = GetNextWorkLocked(queue, context);
160 if (!context->status().ok()) {
161 return;
162 }
163 Status status = OnWorkStartedLocked();
164 if (status.ok()) {
165 work_started_++;
166 } else {
167 context->SetStatus(status);
168 return;
169 }
170 }
171
172 bool produced = false;
173 bool at_end = false;
174 Status status = ReadLocked(key, value, &produced, &at_end);
175
176 if (!at_end && status.ok() && !produced) {
177 status = errors::Internal(
178 "ReadLocked() for ", name(),
179 " must set *at_end=true, *produced=true, or return an error.");
180 }
181 if (!status.ok() && produced) {
182 status = errors::Internal("ReadLocked() for ", name(),
183 " set *produced=true *and* returned an error: ",
184 status.error_message());
185 }
186 if (status.ok() && at_end) {
187 status = OnWorkFinishedLocked();
188 work_finished_ = work_started_;
189 }
190 if (!status.ok()) {
191 context->SetStatus(status);
192 return;
193 }
194 if (produced) {
195 ++num_records_produced_;
196 return;
197 }
198 }
199}
200
201string ReaderBase::GetNextWorkLocked(QueueInterface* queue,
202 OpKernelContext* context) const {
203 string work;
204 Notification n;
205 queue->TryDequeue(
206 context, [context, &n, &work](const QueueInterface::Tuple& tuple) {
207 if (context->status().ok()) {
208 if (tuple.size() != 1) {
209 context->SetStatus(
210 errors::InvalidArgument("Expected single component queue"));
211 } else if (tuple[0].dtype() != DT_STRING) {
212 context->SetStatus(errors::InvalidArgument(
213 "Expected queue with single string component"));
214 } else if (tuple[0].NumElements() != 1) {
215 context->SetStatus(errors::InvalidArgument(
216 "Expected to dequeue a one-element string tensor"));
217 } else {
218 work = tuple[0].flat<tstring>()(0);
219 }
220 }
221 n.Notify();
222 });
223 n.WaitForNotification();
224 return work;
225}
226
227void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
228 state->Clear();
229 state->set_work_started(work_started_);
230 state->set_work_finished(work_finished_);
231 state->set_num_records_produced(num_records_produced_);
232 state->set_current_work(work_.data(), work_.size());
233}
234
235tstring ReaderBase::KeyName(const tstring& key) const {
236 return strings::StrCat(current_work(), ":", key);
237}
238
239Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
240 work_started_ = state.work_started();
241 work_finished_ = state.work_finished();
242 num_records_produced_ = state.num_records_produced();
243 work_ = state.current_work();
244 if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
245#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
246 const string debug_string = "<debug state not available>";
247#else
248 const string debug_string = state.DebugString();
249#endif
250 return errors::InvalidArgument(
251 "Unexpected negative value when restoring in ", name(), ": ",
252 debug_string);
253 }
254 if (work_started_ > work_finished_) {
255#if defined(__ANDROID__) || (__EMSCRIPTEN__)
256 const string debug_string = "<debug state not available>";
257#else
258 const string debug_string = state.DebugString();
259#endif
260 return errors::InvalidArgument(
261 "Inconsistent work started vs. finished when restoring in ", name(),
262 ": ", debug_string);
263 }
264 return OkStatus();
265}
266
267} // namespace tensorflow
268