1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/reader_base.h" |
17 | |
18 | #include "tensorflow/core/framework/reader_base.pb.h" |
19 | #include "tensorflow/core/framework/types.h" |
20 | #include "tensorflow/core/lib/core/coding.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/core/notification.h" |
23 | #include "tensorflow/core/lib/core/stringpiece.h" |
24 | #include "tensorflow/core/lib/strings/str_util.h" |
25 | #include "tensorflow/core/lib/strings/strcat.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // ReaderBase ------------------------------------------------------ |
30 | |
31 | ReaderBase::ReaderBase(const string& name) : name_(name) {} |
32 | |
33 | int64_t ReaderBase::NumRecordsProduced() { |
34 | mutex_lock lock(mu_); |
35 | return num_records_produced_; |
36 | } |
37 | |
38 | int64_t ReaderBase::NumWorkUnitsCompleted() { |
39 | mutex_lock lock(mu_); |
40 | return work_finished_; |
41 | } |
42 | |
43 | Status ReaderBase::Reset() { |
44 | mutex_lock lock(mu_); |
45 | return ResetLocked(); |
46 | } |
47 | |
48 | Status ReaderBase::ResetLocked() { |
49 | work_started_ = 0; |
50 | work_finished_ = 0; |
51 | num_records_produced_ = 0; |
52 | work_.clear(); |
53 | return OkStatus(); |
54 | } |
55 | |
56 | Status ReaderBase::SerializeState(tstring* state) { |
57 | mutex_lock lock(mu_); |
58 | return SerializeStateLocked(state); |
59 | } |
60 | |
61 | Status ReaderBase::SerializeStateLocked(tstring* state) { |
62 | return errors::Unimplemented("Reader SerializeState" ); |
63 | } |
64 | |
65 | Status ReaderBase::RestoreState(const tstring& state) { |
66 | mutex_lock lock(mu_); |
67 | Status status = RestoreStateLocked(state); |
68 | if (!status.ok()) { |
69 | ResetLocked().IgnoreError(); |
70 | } |
71 | return status; |
72 | } |
73 | |
74 | Status ReaderBase::RestoreStateLocked(const tstring& state) { |
75 | return errors::Unimplemented("Reader RestoreState" ); |
76 | } |
77 | |
78 | int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue, |
79 | std::vector<tstring>* keys, |
80 | std::vector<tstring>* values, |
81 | OpKernelContext* context) { |
82 | mutex_lock lock(mu_); |
83 | int64_t records_produced_this_call = 0; |
84 | while (true) { |
85 | // Records produced by this iteration of the ReadUpToLocked call. |
86 | int64_t num_records_produced = 0; |
87 | int64_t remaining = num_records - records_produced_this_call; |
88 | if (remaining == 0) { |
89 | return records_produced_this_call; |
90 | } |
91 | if (!work_in_progress()) { |
92 | work_ = GetNextWorkLocked(queue, context); |
93 | if (!context->status().ok()) { |
94 | return records_produced_this_call; |
95 | } |
96 | Status status = OnWorkStartedLocked(); |
97 | if (status.ok()) { |
98 | work_started_++; |
99 | } else { |
100 | context->SetStatus(status); |
101 | return records_produced_this_call; |
102 | } |
103 | } |
104 | bool at_end = false; |
105 | |
106 | Status status = |
107 | ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end); |
108 | // This call so far. |
109 | records_produced_this_call += num_records_produced; |
110 | |
111 | // In total, over the lifetime of the ReaderBase. |
112 | num_records_produced_ += num_records_produced; |
113 | |
114 | if (!at_end && status.ok() && num_records_produced == 0) { |
115 | status = errors::Internal( |
116 | "ReadManyLocked() for " , name(), |
117 | " must set *at_end=true, *num_produced > 0 or return an error." ); |
118 | context->SetStatus(status); |
119 | return records_produced_this_call; |
120 | } |
121 | if (status.ok() && at_end) { |
122 | status = OnWorkFinishedLocked(); |
123 | work_finished_ = work_started_; |
124 | if (records_produced_this_call > 0) { |
125 | return records_produced_this_call; |
126 | } |
127 | } |
128 | if (!status.ok()) { |
129 | context->SetStatus(status); |
130 | return records_produced_this_call; |
131 | } |
132 | } |
133 | } |
134 | |
135 | // Default implementation just reads one record at a time. |
136 | Status ReaderBase::ReadUpToLocked(int64_t num_records, |
137 | std::vector<tstring>* keys, |
138 | std::vector<tstring>* values, |
139 | int64_t* num_read, bool* at_end) { |
140 | bool produced = false; |
141 | tstring key; |
142 | tstring value; |
143 | Status status = ReadLocked(&key, &value, &produced, at_end); |
144 | if (produced) { |
145 | keys->push_back(std::move(key)); |
146 | values->push_back(std::move(value)); |
147 | *num_read = 1; |
148 | } else { |
149 | *num_read = 0; |
150 | } |
151 | return status; |
152 | } |
153 | |
154 | void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value, |
155 | OpKernelContext* context) { |
156 | mutex_lock lock(mu_); |
157 | while (true) { |
158 | if (!work_in_progress()) { |
159 | work_ = GetNextWorkLocked(queue, context); |
160 | if (!context->status().ok()) { |
161 | return; |
162 | } |
163 | Status status = OnWorkStartedLocked(); |
164 | if (status.ok()) { |
165 | work_started_++; |
166 | } else { |
167 | context->SetStatus(status); |
168 | return; |
169 | } |
170 | } |
171 | |
172 | bool produced = false; |
173 | bool at_end = false; |
174 | Status status = ReadLocked(key, value, &produced, &at_end); |
175 | |
176 | if (!at_end && status.ok() && !produced) { |
177 | status = errors::Internal( |
178 | "ReadLocked() for " , name(), |
179 | " must set *at_end=true, *produced=true, or return an error." ); |
180 | } |
181 | if (!status.ok() && produced) { |
182 | status = errors::Internal("ReadLocked() for " , name(), |
183 | " set *produced=true *and* returned an error: " , |
184 | status.error_message()); |
185 | } |
186 | if (status.ok() && at_end) { |
187 | status = OnWorkFinishedLocked(); |
188 | work_finished_ = work_started_; |
189 | } |
190 | if (!status.ok()) { |
191 | context->SetStatus(status); |
192 | return; |
193 | } |
194 | if (produced) { |
195 | ++num_records_produced_; |
196 | return; |
197 | } |
198 | } |
199 | } |
200 | |
201 | string ReaderBase::GetNextWorkLocked(QueueInterface* queue, |
202 | OpKernelContext* context) const { |
203 | string work; |
204 | Notification n; |
205 | queue->TryDequeue( |
206 | context, [context, &n, &work](const QueueInterface::Tuple& tuple) { |
207 | if (context->status().ok()) { |
208 | if (tuple.size() != 1) { |
209 | context->SetStatus( |
210 | errors::InvalidArgument("Expected single component queue" )); |
211 | } else if (tuple[0].dtype() != DT_STRING) { |
212 | context->SetStatus(errors::InvalidArgument( |
213 | "Expected queue with single string component" )); |
214 | } else if (tuple[0].NumElements() != 1) { |
215 | context->SetStatus(errors::InvalidArgument( |
216 | "Expected to dequeue a one-element string tensor" )); |
217 | } else { |
218 | work = tuple[0].flat<tstring>()(0); |
219 | } |
220 | } |
221 | n.Notify(); |
222 | }); |
223 | n.WaitForNotification(); |
224 | return work; |
225 | } |
226 | |
227 | void ReaderBase::SaveBaseState(ReaderBaseState* state) const { |
228 | state->Clear(); |
229 | state->set_work_started(work_started_); |
230 | state->set_work_finished(work_finished_); |
231 | state->set_num_records_produced(num_records_produced_); |
232 | state->set_current_work(work_.data(), work_.size()); |
233 | } |
234 | |
235 | tstring ReaderBase::KeyName(const tstring& key) const { |
236 | return strings::StrCat(current_work(), ":" , key); |
237 | } |
238 | |
239 | Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { |
240 | work_started_ = state.work_started(); |
241 | work_finished_ = state.work_finished(); |
242 | num_records_produced_ = state.num_records_produced(); |
243 | work_ = state.current_work(); |
244 | if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) { |
245 | #if defined(__ANDROID__) || defined(__EMSCRIPTEN__) |
246 | const string debug_string = "<debug state not available>" ; |
247 | #else |
248 | const string debug_string = state.DebugString(); |
249 | #endif |
250 | return errors::InvalidArgument( |
251 | "Unexpected negative value when restoring in " , name(), ": " , |
252 | debug_string); |
253 | } |
254 | if (work_started_ > work_finished_) { |
255 | #if defined(__ANDROID__) || (__EMSCRIPTEN__) |
256 | const string debug_string = "<debug state not available>" ; |
257 | #else |
258 | const string debug_string = state.DebugString(); |
259 | #endif |
260 | return errors::InvalidArgument( |
261 | "Inconsistent work started vs. finished when restoring in " , name(), |
262 | ": " , debug_string); |
263 | } |
264 | return OkStatus(); |
265 | } |
266 | |
267 | } // namespace tensorflow |
268 | |