1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/io_ops.cc. |
17 | |
18 | #include <memory> |
19 | #include "tensorflow/core/framework/reader_base.h" |
20 | #include "tensorflow/core/framework/reader_op_kernel.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/io/record_reader.h" |
23 | #include "tensorflow/core/lib/strings/strcat.h" |
24 | #include "tensorflow/core/platform/env.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | class TFRecordReader : public ReaderBase { |
29 | public: |
30 | TFRecordReader(const string& node_name, const string& compression_type, |
31 | Env* env) |
32 | : ReaderBase(strings::StrCat("TFRecordReader '" , node_name, "'" )), |
33 | env_(env), |
34 | offset_(0), |
35 | compression_type_(compression_type) {} |
36 | |
37 | Status OnWorkStartedLocked() override { |
38 | offset_ = 0; |
39 | TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); |
40 | |
41 | io::RecordReaderOptions options = |
42 | io::RecordReaderOptions::CreateRecordReaderOptions(compression_type_); |
43 | reader_.reset(new io::RecordReader(file_.get(), options)); |
44 | return OkStatus(); |
45 | } |
46 | |
47 | Status OnWorkFinishedLocked() override { |
48 | reader_.reset(nullptr); |
49 | file_.reset(nullptr); |
50 | return OkStatus(); |
51 | } |
52 | |
53 | Status ReadLocked(tstring* key, tstring* value, bool* produced, |
54 | bool* at_end) override { |
55 | *key = strings::StrCat(current_work(), ":" , offset_); |
56 | Status status = reader_->ReadRecord(&offset_, value); |
57 | if (errors::IsOutOfRange(status)) { |
58 | *at_end = true; |
59 | return OkStatus(); |
60 | } |
61 | if (!status.ok()) return status; |
62 | *produced = true; |
63 | return OkStatus(); |
64 | } |
65 | |
66 | Status ResetLocked() override { |
67 | offset_ = 0; |
68 | reader_.reset(nullptr); |
69 | file_.reset(nullptr); |
70 | return ReaderBase::ResetLocked(); |
71 | } |
72 | |
73 | // TODO(josh11b): Implement serializing and restoring the state. |
74 | |
75 | private: |
76 | Env* const env_; |
77 | uint64 offset_; |
78 | std::unique_ptr<RandomAccessFile> file_; |
79 | std::unique_ptr<io::RecordReader> reader_; |
80 | string compression_type_ = "" ; |
81 | }; |
82 | |
83 | class TFRecordReaderOp : public ReaderOpKernel { |
84 | public: |
85 | explicit TFRecordReaderOp(OpKernelConstruction* context) |
86 | : ReaderOpKernel(context) { |
87 | Env* env = context->env(); |
88 | |
89 | string compression_type; |
90 | OP_REQUIRES_OK(context, |
91 | context->GetAttr("compression_type" , &compression_type)); |
92 | |
93 | SetReaderFactory([this, compression_type, env]() { |
94 | return new TFRecordReader(name(), compression_type, env); |
95 | }); |
96 | } |
97 | }; |
98 | |
99 | REGISTER_KERNEL_BUILDER(Name("TFRecordReader" ).Device(DEVICE_CPU), |
100 | TFRecordReaderOp); |
101 | REGISTER_KERNEL_BUILDER(Name("TFRecordReaderV2" ).Device(DEVICE_CPU), |
102 | TFRecordReaderOp); |
103 | |
104 | } // namespace tensorflow |
105 | |