1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/io_ops.cc. |
17 | |
18 | #include <memory> |
19 | #include "tensorflow/core/framework/reader_base.h" |
20 | #include "tensorflow/core/framework/reader_op_kernel.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/io/inputbuffer.h" |
23 | #include "tensorflow/core/lib/strings/strcat.h" |
24 | #include "tensorflow/core/platform/env.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | class TextLineReader : public ReaderBase { |
29 | public: |
30 | TextLineReader(const string& node_name, int , Env* env) |
31 | : ReaderBase(strings::StrCat("TextLineReader '" , node_name, "'" )), |
32 | skip_header_lines_(skip_header_lines), |
33 | env_(env), |
34 | line_number_(0) {} |
35 | |
36 | Status OnWorkStartedLocked() override { |
37 | line_number_ = 0; |
38 | TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); |
39 | |
40 | input_buffer_.reset(new io::InputBuffer(file_.get(), kBufferSize)); |
41 | for (; line_number_ < skip_header_lines_; ++line_number_) { |
42 | string line_contents; |
43 | Status status = input_buffer_->ReadLine(&line_contents); |
44 | if (errors::IsOutOfRange(status)) { |
45 | // We ignore an end of file error when skipping header lines. |
46 | // We will end up skipping this file. |
47 | return OkStatus(); |
48 | } |
49 | TF_RETURN_IF_ERROR(status); |
50 | } |
51 | return OkStatus(); |
52 | } |
53 | |
54 | Status OnWorkFinishedLocked() override { |
55 | input_buffer_.reset(nullptr); |
56 | return OkStatus(); |
57 | } |
58 | |
59 | Status ReadLocked(tstring* key, tstring* value, bool* produced, |
60 | bool* at_end) override { |
61 | Status status = input_buffer_->ReadLine(value); |
62 | ++line_number_; |
63 | if (status.ok()) { |
64 | *key = strings::StrCat(current_work(), ":" , line_number_); |
65 | *produced = true; |
66 | return status; |
67 | } |
68 | if (errors::IsOutOfRange(status)) { // End of file, advance to the next. |
69 | *at_end = true; |
70 | return OkStatus(); |
71 | } else { // Some other reading error |
72 | return status; |
73 | } |
74 | } |
75 | |
76 | Status ResetLocked() override { |
77 | line_number_ = 0; |
78 | input_buffer_.reset(nullptr); |
79 | return ReaderBase::ResetLocked(); |
80 | } |
81 | |
82 | // TODO(josh11b): Implement serializing and restoring the state. Need |
83 | // to create TextLineReaderState proto to store ReaderBaseState, |
84 | // line_number_, and input_buffer_->Tell(). |
85 | |
86 | private: |
87 | enum { kBufferSize = 256 << 10 /* 256 kB */ }; |
88 | const int ; |
89 | Env* const env_; |
90 | int64_t line_number_; |
91 | std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_ |
92 | std::unique_ptr<io::InputBuffer> input_buffer_; |
93 | }; |
94 | |
95 | class TextLineReaderOp : public ReaderOpKernel { |
96 | public: |
97 | explicit TextLineReaderOp(OpKernelConstruction* context) |
98 | : ReaderOpKernel(context) { |
99 | int = -1; |
100 | OP_REQUIRES_OK(context, |
101 | context->GetAttr("skip_header_lines" , &skip_header_lines)); |
102 | OP_REQUIRES(context, skip_header_lines >= 0, |
103 | errors::InvalidArgument("skip_header_lines must be >= 0 not " , |
104 | skip_header_lines)); |
105 | Env* env = context->env(); |
106 | SetReaderFactory([this, skip_header_lines, env]() { |
107 | return new TextLineReader(name(), skip_header_lines, env); |
108 | }); |
109 | } |
110 | }; |
111 | |
112 | REGISTER_KERNEL_BUILDER(Name("TextLineReader" ).Device(DEVICE_CPU), |
113 | TextLineReaderOp); |
114 | REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2" ).Device(DEVICE_CPU), |
115 | TextLineReaderOp); |
116 | |
117 | } // namespace tensorflow |
118 | |