1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/io_ops.cc.
17
18#include <memory>
19#include "tensorflow/core/framework/reader_base.h"
20#include "tensorflow/core/framework/reader_op_kernel.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/io/inputbuffer.h"
23#include "tensorflow/core/lib/strings/strcat.h"
24#include "tensorflow/core/platform/env.h"
25
26namespace tensorflow {
27
28class TextLineReader : public ReaderBase {
29 public:
30 TextLineReader(const string& node_name, int skip_header_lines, Env* env)
31 : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")),
32 skip_header_lines_(skip_header_lines),
33 env_(env),
34 line_number_(0) {}
35
36 Status OnWorkStartedLocked() override {
37 line_number_ = 0;
38 TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_));
39
40 input_buffer_.reset(new io::InputBuffer(file_.get(), kBufferSize));
41 for (; line_number_ < skip_header_lines_; ++line_number_) {
42 string line_contents;
43 Status status = input_buffer_->ReadLine(&line_contents);
44 if (errors::IsOutOfRange(status)) {
45 // We ignore an end of file error when skipping header lines.
46 // We will end up skipping this file.
47 return OkStatus();
48 }
49 TF_RETURN_IF_ERROR(status);
50 }
51 return OkStatus();
52 }
53
54 Status OnWorkFinishedLocked() override {
55 input_buffer_.reset(nullptr);
56 return OkStatus();
57 }
58
59 Status ReadLocked(tstring* key, tstring* value, bool* produced,
60 bool* at_end) override {
61 Status status = input_buffer_->ReadLine(value);
62 ++line_number_;
63 if (status.ok()) {
64 *key = strings::StrCat(current_work(), ":", line_number_);
65 *produced = true;
66 return status;
67 }
68 if (errors::IsOutOfRange(status)) { // End of file, advance to the next.
69 *at_end = true;
70 return OkStatus();
71 } else { // Some other reading error
72 return status;
73 }
74 }
75
76 Status ResetLocked() override {
77 line_number_ = 0;
78 input_buffer_.reset(nullptr);
79 return ReaderBase::ResetLocked();
80 }
81
82 // TODO(josh11b): Implement serializing and restoring the state. Need
83 // to create TextLineReaderState proto to store ReaderBaseState,
84 // line_number_, and input_buffer_->Tell().
85
86 private:
87 enum { kBufferSize = 256 << 10 /* 256 kB */ };
88 const int skip_header_lines_;
89 Env* const env_;
90 int64_t line_number_;
91 std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_
92 std::unique_ptr<io::InputBuffer> input_buffer_;
93};
94
95class TextLineReaderOp : public ReaderOpKernel {
96 public:
97 explicit TextLineReaderOp(OpKernelConstruction* context)
98 : ReaderOpKernel(context) {
99 int skip_header_lines = -1;
100 OP_REQUIRES_OK(context,
101 context->GetAttr("skip_header_lines", &skip_header_lines));
102 OP_REQUIRES(context, skip_header_lines >= 0,
103 errors::InvalidArgument("skip_header_lines must be >= 0 not ",
104 skip_header_lines));
105 Env* env = context->env();
106 SetReaderFactory([this, skip_header_lines, env]() {
107 return new TextLineReader(name(), skip_header_lines, env);
108 });
109 }
110};
111
112REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
113 TextLineReaderOp);
114REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2").Device(DEVICE_CPU),
115 TextLineReaderOp);
116
117} // namespace tensorflow
118