1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ |
18 | |
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/reader_interface.h"
#include "tensorflow/core/lib/core/stringpiece.h"
24 | |
25 | namespace tensorflow { |
26 | |
27 | class ReaderBaseState; |
28 | |
29 | // Default implementation of ReaderInterface. |
30 | class ReaderBase : public ReaderInterface { |
31 | public: |
32 | // name: For use in error messages, should mention both the name of |
33 | // the op and the node. |
34 | explicit ReaderBase(const string& name); |
35 | |
36 | // Note that methods with names ending in "Locked" are called while |
37 | // the ReaderBase's mutex is held. |
38 | |
39 | // Implement this function in descendants ----------------------------------- |
40 | |
41 | // Produce the next key/value pair from the current work item. |
42 | // This is called "Locked" since it is executed under a mutex |
43 | // that serializes all Reader calls. |
44 | // Usage: |
45 | // a) If a record was successfully produced, set *produced = true, |
46 | // and fill in *key and *value. |
47 | // b) If no more records will be produced for this work item, set |
48 | // *at_end = true. |
49 | // c) If a record was produced, but no more will be produced, you |
50 | // may either do both (a) and (b), or do (a) in this call and do (b) in |
51 | // the next call to ReadLocked(). |
52 | // d) If there was an error producing (e.g. an error reading the file, |
53 | // data corruption), return a non-OK() status. ReadLocked may be |
54 | // called again if the user reruns this part of the graph. |
55 | virtual Status ReadLocked(tstring* key, tstring* value, bool* produced, |
56 | bool* at_end) = 0; |
57 | |
58 | // Descendants may optionally implement these ------------------------------- |
59 | |
60 | // Produce up to num_records next key/value pairs from the current |
61 | // work item, in the same manner of ReadLocked. |
62 | virtual Status ReadUpToLocked(int64_t num_records, std::vector<tstring>* keys, |
63 | std::vector<tstring>* values, int64_t* num_read, |
64 | bool* at_end); |
65 | |
66 | // Called when work starts / finishes. |
67 | virtual Status OnWorkStartedLocked() { return OkStatus(); } |
68 | virtual Status OnWorkFinishedLocked() { return OkStatus(); } |
69 | |
70 | // Called to reset the Reader to a newly constructed state. |
71 | virtual Status ResetLocked(); |
72 | |
73 | // Default implementation generates an Unimplemented error. |
74 | // See the protected helper methods below. |
75 | virtual Status SerializeStateLocked(tstring* state); |
76 | virtual Status RestoreStateLocked(const tstring& state); |
77 | |
78 | // Accessors ---------------------------------------------------------------- |
79 | |
80 | // Always true during a call to ReadLocked(). |
81 | bool work_in_progress() const { return work_finished_ < work_started_; } |
82 | |
83 | // Returns the name of the current work item (valid if |
84 | // work_in_progress() returns true). May change between calls to |
85 | // ReadLocked(). |
86 | const tstring& current_work() const { return work_; } |
87 | |
88 | // What was passed to the constructor. |
89 | const string& name() const { return name_; } |
90 | |
91 | // Produce the key name (from current_work and the actual key). |
92 | tstring KeyName(const tstring& key) const; |
93 | |
94 | protected: |
95 | // For descendants wishing to implement serialize & restore state. |
96 | |
97 | // Writes ReaderBase state to *state. |
98 | void SaveBaseState(ReaderBaseState* state) const; |
99 | |
100 | // Restores ReaderBase state from state. Assumes state was filled |
101 | // using SaveBaseState() above. |
102 | Status RestoreBaseState(const ReaderBaseState& state); |
103 | |
104 | private: |
105 | // For descendants that wish to obtain the next work item in a different way. |
106 | // For implementing Read(). Dequeues the next work item from |
107 | // *queue, and if successful returns "work" (a string). May block. |
108 | virtual string GetNextWorkLocked(QueueInterface* queue, |
109 | OpKernelContext* context) const; |
110 | |
111 | // Implementations of ReaderInterface methods. These ensure thread-safety |
112 | // and call the methods above to do the work. |
113 | void Read(QueueInterface* queue, tstring* key, tstring* value, |
114 | OpKernelContext* context) override; |
115 | |
116 | // Produces up to num_records. |
117 | // In this implementation all the records come from the same work unit. |
118 | int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue, |
119 | std::vector<tstring>* keys, std::vector<tstring>* value, |
120 | OpKernelContext* context) override; |
121 | |
122 | Status Reset() override; |
123 | int64_t NumRecordsProduced() override; |
124 | int64_t NumWorkUnitsCompleted() override; |
125 | Status SerializeState(tstring* state) override; |
126 | Status RestoreState(const tstring& state) override; |
127 | |
128 | mutable mutex mu_; |
129 | const string name_; |
130 | int64_t work_started_ = 0; |
131 | int64_t work_finished_ = 0; |
132 | int64_t num_records_produced_ = 0; |
133 | tstring work_; |
134 | }; |
135 | |
136 | } // namespace tensorflow |
137 | |
138 | #endif // TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ |
139 | |