1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/io_ops.cc. |
17 | |
18 | #include <memory> |
19 | #include "tensorflow/core/framework/reader_base.h" |
20 | #include "tensorflow/core/framework/reader_op_kernel.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/io/buffered_inputstream.h" |
23 | #include "tensorflow/core/lib/io/random_inputstream.h" |
24 | #include "tensorflow/core/lib/io/zlib_compression_options.h" |
25 | #include "tensorflow/core/lib/io/zlib_inputstream.h" |
26 | #include "tensorflow/core/lib/strings/strcat.h" |
27 | #include "tensorflow/core/platform/env.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | // In the constructor hop_bytes_ is set to record_bytes_ if it was 0, |
32 | // so that we will always "hop" after each read (except first). |
33 | class FixedLengthRecordReader : public ReaderBase { |
34 | public: |
35 | FixedLengthRecordReader(const string& node_name, int64_t , |
36 | int64_t record_bytes, int64_t , |
37 | int64_t hop_bytes, const string& encoding, Env* env) |
38 | : ReaderBase( |
39 | strings::StrCat("FixedLengthRecordReader '" , node_name, "'" )), |
40 | header_bytes_(header_bytes), |
41 | record_bytes_(record_bytes), |
42 | footer_bytes_(footer_bytes), |
43 | hop_bytes_(hop_bytes == 0 ? record_bytes : hop_bytes), |
44 | env_(env), |
45 | record_number_(0), |
46 | encoding_(encoding) {} |
47 | |
48 | // On success: |
49 | // * buffered_inputstream_ != nullptr, |
50 | // * buffered_inputstream_->Tell() == header_bytes_ |
51 | Status OnWorkStartedLocked() override { |
52 | record_number_ = 0; |
53 | |
54 | lookahead_cache_.clear(); |
55 | |
56 | TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); |
57 | if (encoding_ == "ZLIB" || encoding_ == "GZIP" ) { |
58 | const io::ZlibCompressionOptions zlib_options = |
59 | encoding_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT() |
60 | : io::ZlibCompressionOptions::GZIP(); |
61 | file_stream_.reset(new io::RandomAccessInputStream(file_.get())); |
62 | buffered_inputstream_.reset(new io::ZlibInputStream( |
63 | file_stream_.get(), static_cast<size_t>(kBufferSize), |
64 | static_cast<size_t>(kBufferSize), zlib_options)); |
65 | } else { |
66 | buffered_inputstream_.reset( |
67 | new io::BufferedInputStream(file_.get(), kBufferSize)); |
68 | } |
69 | // header_bytes_ is always skipped. |
70 | TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(header_bytes_)); |
71 | |
72 | return OkStatus(); |
73 | } |
74 | |
75 | Status OnWorkFinishedLocked() override { |
76 | buffered_inputstream_.reset(nullptr); |
77 | return OkStatus(); |
78 | } |
79 | |
80 | Status ReadLocked(tstring* key, tstring* value, bool* produced, |
81 | bool* at_end) override { |
82 | // We will always "hop" the hop_bytes_ except the first record |
83 | // where record_number_ == 0 |
84 | if (record_number_ != 0) { |
85 | if (hop_bytes_ <= lookahead_cache_.size()) { |
86 | // If hop_bytes_ is smaller than the cached data we skip the |
87 | // hop_bytes_ from the cache. |
88 | lookahead_cache_ = lookahead_cache_.substr(hop_bytes_); |
89 | } else { |
90 | // If hop_bytes_ is larger than the cached data, we clean up |
91 | // the cache, then skip hop_bytes_ - cache_size from the file |
92 | // as the cache_size has been skipped through cache. |
93 | int64_t cache_size = lookahead_cache_.size(); |
94 | lookahead_cache_.clear(); |
95 | Status s = buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size); |
96 | if (!s.ok()) { |
97 | if (!errors::IsOutOfRange(s)) { |
98 | return s; |
99 | } |
100 | *at_end = true; |
101 | return OkStatus(); |
102 | } |
103 | } |
104 | } |
105 | |
106 | // Fill up lookahead_cache_ to record_bytes_ + footer_bytes_ |
107 | int bytes_to_read = record_bytes_ + footer_bytes_ - lookahead_cache_.size(); |
108 | Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value); |
109 | if (!s.ok()) { |
110 | value->clear(); |
111 | if (!errors::IsOutOfRange(s)) { |
112 | return s; |
113 | } |
114 | *at_end = true; |
115 | return OkStatus(); |
116 | } |
117 | lookahead_cache_.append(*value, 0, bytes_to_read); |
118 | value->clear(); |
119 | |
120 | // Copy first record_bytes_ from cache to value |
121 | *value = lookahead_cache_.substr(0, record_bytes_); |
122 | |
123 | *key = strings::StrCat(current_work(), ":" , record_number_); |
124 | *produced = true; |
125 | ++record_number_; |
126 | |
127 | return OkStatus(); |
128 | } |
129 | |
130 | Status ResetLocked() override { |
131 | record_number_ = 0; |
132 | buffered_inputstream_.reset(nullptr); |
133 | lookahead_cache_.clear(); |
134 | return ReaderBase::ResetLocked(); |
135 | } |
136 | |
137 | // TODO(josh11b): Implement serializing and restoring the state. |
138 | |
139 | private: |
140 | enum { kBufferSize = 256 << 10 /* 256 kB */ }; |
141 | const int64_t ; |
142 | const int64_t record_bytes_; |
143 | const int64_t ; |
144 | const int64_t hop_bytes_; |
145 | // The purpose of lookahead_cache_ is to allows "one-pass" processing |
146 | // without revisit previous processed data of the stream. This is needed |
147 | // because certain compression like zlib does not allow random access |
148 | // or even obtain the uncompressed stream size before hand. |
149 | // The max size of the lookahead_cache_ could be |
150 | // record_bytes_ + footer_bytes_ |
151 | string lookahead_cache_; |
152 | Env* const env_; |
153 | int64_t record_number_; |
154 | string encoding_; |
155 | // must outlive buffered_inputstream_ |
156 | std::unique_ptr<RandomAccessFile> file_; |
157 | // must outlive buffered_inputstream_ |
158 | std::unique_ptr<io::RandomAccessInputStream> file_stream_; |
159 | std::unique_ptr<io::InputStreamInterface> buffered_inputstream_; |
160 | }; |
161 | |
162 | class FixedLengthRecordReaderOp : public ReaderOpKernel { |
163 | public: |
164 | explicit FixedLengthRecordReaderOp(OpKernelConstruction* context) |
165 | : ReaderOpKernel(context) { |
166 | int64_t = -1, record_bytes = -1, = -1, |
167 | hop_bytes = -1; |
168 | OP_REQUIRES_OK(context, context->GetAttr("header_bytes" , &header_bytes)); |
169 | OP_REQUIRES_OK(context, context->GetAttr("record_bytes" , &record_bytes)); |
170 | OP_REQUIRES_OK(context, context->GetAttr("footer_bytes" , &footer_bytes)); |
171 | OP_REQUIRES_OK(context, context->GetAttr("hop_bytes" , &hop_bytes)); |
172 | OP_REQUIRES(context, header_bytes >= 0, |
173 | errors::InvalidArgument("header_bytes must be >= 0 not " , |
174 | header_bytes)); |
175 | OP_REQUIRES(context, record_bytes >= 0, |
176 | errors::InvalidArgument("record_bytes must be >= 0 not " , |
177 | record_bytes)); |
178 | OP_REQUIRES(context, footer_bytes >= 0, |
179 | errors::InvalidArgument("footer_bytes must be >= 0 not " , |
180 | footer_bytes)); |
181 | OP_REQUIRES( |
182 | context, hop_bytes >= 0, |
183 | errors::InvalidArgument("hop_bytes must be >= 0 not " , hop_bytes)); |
184 | Env* env = context->env(); |
185 | string encoding; |
186 | OP_REQUIRES_OK(context, context->GetAttr("encoding" , &encoding)); |
187 | SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, hop_bytes, |
188 | encoding, env]() { |
189 | return new FixedLengthRecordReader(name(), header_bytes, record_bytes, |
190 | footer_bytes, hop_bytes, encoding, |
191 | env); |
192 | }); |
193 | } |
194 | }; |
195 | |
// Both the original op and its V2 variant are served by the same CPU kernel
// class.
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
                        FixedLengthRecordReaderOp);
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU),
                        FixedLengthRecordReaderOp);
200 | |
201 | } // namespace tensorflow |
202 | |