1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/io_ops.cc.
17
18#include <memory>
19#include "tensorflow/core/framework/reader_base.h"
20#include "tensorflow/core/framework/reader_op_kernel.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/io/buffered_inputstream.h"
23#include "tensorflow/core/lib/io/random_inputstream.h"
24#include "tensorflow/core/lib/io/zlib_compression_options.h"
25#include "tensorflow/core/lib/io/zlib_inputstream.h"
26#include "tensorflow/core/lib/strings/strcat.h"
27#include "tensorflow/core/platform/env.h"
28
29namespace tensorflow {
30
31// In the constructor hop_bytes_ is set to record_bytes_ if it was 0,
32// so that we will always "hop" after each read (except first).
33class FixedLengthRecordReader : public ReaderBase {
34 public:
35 FixedLengthRecordReader(const string& node_name, int64_t header_bytes,
36 int64_t record_bytes, int64_t footer_bytes,
37 int64_t hop_bytes, const string& encoding, Env* env)
38 : ReaderBase(
39 strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
40 header_bytes_(header_bytes),
41 record_bytes_(record_bytes),
42 footer_bytes_(footer_bytes),
43 hop_bytes_(hop_bytes == 0 ? record_bytes : hop_bytes),
44 env_(env),
45 record_number_(0),
46 encoding_(encoding) {}
47
48 // On success:
49 // * buffered_inputstream_ != nullptr,
50 // * buffered_inputstream_->Tell() == header_bytes_
51 Status OnWorkStartedLocked() override {
52 record_number_ = 0;
53
54 lookahead_cache_.clear();
55
56 TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_));
57 if (encoding_ == "ZLIB" || encoding_ == "GZIP") {
58 const io::ZlibCompressionOptions zlib_options =
59 encoding_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT()
60 : io::ZlibCompressionOptions::GZIP();
61 file_stream_.reset(new io::RandomAccessInputStream(file_.get()));
62 buffered_inputstream_.reset(new io::ZlibInputStream(
63 file_stream_.get(), static_cast<size_t>(kBufferSize),
64 static_cast<size_t>(kBufferSize), zlib_options));
65 } else {
66 buffered_inputstream_.reset(
67 new io::BufferedInputStream(file_.get(), kBufferSize));
68 }
69 // header_bytes_ is always skipped.
70 TF_RETURN_IF_ERROR(buffered_inputstream_->SkipNBytes(header_bytes_));
71
72 return OkStatus();
73 }
74
75 Status OnWorkFinishedLocked() override {
76 buffered_inputstream_.reset(nullptr);
77 return OkStatus();
78 }
79
80 Status ReadLocked(tstring* key, tstring* value, bool* produced,
81 bool* at_end) override {
82 // We will always "hop" the hop_bytes_ except the first record
83 // where record_number_ == 0
84 if (record_number_ != 0) {
85 if (hop_bytes_ <= lookahead_cache_.size()) {
86 // If hop_bytes_ is smaller than the cached data we skip the
87 // hop_bytes_ from the cache.
88 lookahead_cache_ = lookahead_cache_.substr(hop_bytes_);
89 } else {
90 // If hop_bytes_ is larger than the cached data, we clean up
91 // the cache, then skip hop_bytes_ - cache_size from the file
92 // as the cache_size has been skipped through cache.
93 int64_t cache_size = lookahead_cache_.size();
94 lookahead_cache_.clear();
95 Status s = buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size);
96 if (!s.ok()) {
97 if (!errors::IsOutOfRange(s)) {
98 return s;
99 }
100 *at_end = true;
101 return OkStatus();
102 }
103 }
104 }
105
106 // Fill up lookahead_cache_ to record_bytes_ + footer_bytes_
107 int bytes_to_read = record_bytes_ + footer_bytes_ - lookahead_cache_.size();
108 Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value);
109 if (!s.ok()) {
110 value->clear();
111 if (!errors::IsOutOfRange(s)) {
112 return s;
113 }
114 *at_end = true;
115 return OkStatus();
116 }
117 lookahead_cache_.append(*value, 0, bytes_to_read);
118 value->clear();
119
120 // Copy first record_bytes_ from cache to value
121 *value = lookahead_cache_.substr(0, record_bytes_);
122
123 *key = strings::StrCat(current_work(), ":", record_number_);
124 *produced = true;
125 ++record_number_;
126
127 return OkStatus();
128 }
129
130 Status ResetLocked() override {
131 record_number_ = 0;
132 buffered_inputstream_.reset(nullptr);
133 lookahead_cache_.clear();
134 return ReaderBase::ResetLocked();
135 }
136
137 // TODO(josh11b): Implement serializing and restoring the state.
138
139 private:
140 enum { kBufferSize = 256 << 10 /* 256 kB */ };
141 const int64_t header_bytes_;
142 const int64_t record_bytes_;
143 const int64_t footer_bytes_;
144 const int64_t hop_bytes_;
145 // The purpose of lookahead_cache_ is to allows "one-pass" processing
146 // without revisit previous processed data of the stream. This is needed
147 // because certain compression like zlib does not allow random access
148 // or even obtain the uncompressed stream size before hand.
149 // The max size of the lookahead_cache_ could be
150 // record_bytes_ + footer_bytes_
151 string lookahead_cache_;
152 Env* const env_;
153 int64_t record_number_;
154 string encoding_;
155 // must outlive buffered_inputstream_
156 std::unique_ptr<RandomAccessFile> file_;
157 // must outlive buffered_inputstream_
158 std::unique_ptr<io::RandomAccessInputStream> file_stream_;
159 std::unique_ptr<io::InputStreamInterface> buffered_inputstream_;
160};
161
162class FixedLengthRecordReaderOp : public ReaderOpKernel {
163 public:
164 explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
165 : ReaderOpKernel(context) {
166 int64_t header_bytes = -1, record_bytes = -1, footer_bytes = -1,
167 hop_bytes = -1;
168 OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
169 OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
170 OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
171 OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes));
172 OP_REQUIRES(context, header_bytes >= 0,
173 errors::InvalidArgument("header_bytes must be >= 0 not ",
174 header_bytes));
175 OP_REQUIRES(context, record_bytes >= 0,
176 errors::InvalidArgument("record_bytes must be >= 0 not ",
177 record_bytes));
178 OP_REQUIRES(context, footer_bytes >= 0,
179 errors::InvalidArgument("footer_bytes must be >= 0 not ",
180 footer_bytes));
181 OP_REQUIRES(
182 context, hop_bytes >= 0,
183 errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes));
184 Env* env = context->env();
185 string encoding;
186 OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding));
187 SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, hop_bytes,
188 encoding, env]() {
189 return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
190 footer_bytes, hop_bytes, encoding,
191 env);
192 });
193 }
194};
195
// Both the V1 (ref-based) and V2 (resource-based) reader ops share the same
// CPU kernel implementation.
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
                        FixedLengthRecordReaderOp);
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU),
                        FixedLengthRecordReaderOp);
200
201} // namespace tensorflow
202