1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/reader_base.h" |
17 | #include "tensorflow/core/framework/reader_op_kernel.h" |
18 | #include "tensorflow/core/lib/core/errors.h" |
19 | |
20 | #include <sys/stat.h> |
21 | #include "lmdb.h" |
22 | |
23 | namespace tensorflow { |
24 | |
// Fails a CHECK (via CHECK_EQ) when an LMDB call does not return MDB_SUCCESS,
// streaming LMDB's textual description of the returned code into the message.
//
// `val` is evaluated exactly once and its result is cached; the message is
// built from the cached code. The earlier single-expression form expanded
// `val` twice (once in CHECK_EQ, once in mdb_strerror), so when `val` was a
// function call, a failure could run that call a second time just to format
// the error text.
//
// The macro expands to one complete statement: use it as `MDB_CHECK(expr);`.
#define MDB_CHECK(val)                                                   \
  do {                                                                   \
    const int mdb_check_rc_ = (val);                                     \
    CHECK_EQ(mdb_check_rc_, MDB_SUCCESS) << mdb_strerror(mdb_check_rc_); \
  } while (0)
26 | |
27 | class LMDBReader : public ReaderBase { |
28 | public: |
29 | LMDBReader(const string& node_name, Env* /*unused*/) |
30 | : ReaderBase(strings::StrCat("LMDBReader '" , node_name, "'" )), |
31 | mdb_env_(nullptr), |
32 | mdb_dbi_(0), |
33 | mdb_txn_(nullptr), |
34 | mdb_cursor_(nullptr) {} |
35 | |
36 | Status OnWorkStartedLocked() override { |
37 | MDB_CHECK(mdb_env_create(&mdb_env_)); |
38 | int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; |
39 | |
40 | // Check if the LMDB filename is actually a file instead of a directory. |
41 | // If so, set appropriate flags so we can open it. |
42 | struct stat source_stat; |
43 | if (stat(current_work().c_str(), &source_stat) == 0 && |
44 | (source_stat.st_mode & S_IFREG)) { |
45 | flags |= MDB_NOSUBDIR; |
46 | } |
47 | |
48 | MDB_CHECK(mdb_env_open(mdb_env_, current_work().c_str(), flags, 0664)); |
49 | MDB_CHECK(mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_)); |
50 | MDB_CHECK(mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_)); |
51 | |
52 | return OkStatus(); |
53 | } |
54 | |
55 | Status OnWorkFinishedLocked() override { |
56 | if (mdb_env_ != nullptr) { |
57 | if (mdb_cursor_) { |
58 | mdb_cursor_close(mdb_cursor_); |
59 | mdb_cursor_ = nullptr; |
60 | } |
61 | mdb_dbi_close(mdb_env_, mdb_dbi_); |
62 | mdb_txn_abort(mdb_txn_); |
63 | mdb_env_close(mdb_env_); |
64 | mdb_txn_ = nullptr; |
65 | mdb_dbi_ = 0; |
66 | mdb_env_ = nullptr; |
67 | } |
68 | return OkStatus(); |
69 | } |
70 | |
71 | Status ReadLocked(tstring* key, tstring* value, bool* produced, |
72 | bool* at_end) override { |
73 | if (mdb_cursor_ == nullptr) { |
74 | MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_)); |
75 | if (Seek(MDB_FIRST) == false) { |
76 | *at_end = true; |
77 | return OkStatus(); |
78 | } |
79 | } else { |
80 | if (Seek(MDB_NEXT) == false) { |
81 | *at_end = true; |
82 | return OkStatus(); |
83 | } |
84 | } |
85 | *key = |
86 | tstring(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size); |
87 | *value = tstring(static_cast<const char*>(mdb_value_.mv_data), |
88 | mdb_value_.mv_size); |
89 | *produced = true; |
90 | return OkStatus(); |
91 | } |
92 | |
93 | Status ResetLocked() override { |
94 | CHECK_EQ(Seek(MDB_FIRST), true); |
95 | return ReaderBase::ResetLocked(); |
96 | } |
97 | |
98 | private: |
99 | bool Seek(MDB_cursor_op op) { |
100 | CHECK_NOTNULL(mdb_cursor_); |
101 | int mdb_status = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, op); |
102 | if (mdb_status == MDB_NOTFOUND) { |
103 | return false; |
104 | } else { |
105 | MDB_CHECK(mdb_status); |
106 | return true; |
107 | } |
108 | } |
109 | |
110 | MDB_env* mdb_env_; |
111 | MDB_dbi mdb_dbi_; |
112 | |
113 | MDB_txn* mdb_txn_; |
114 | MDB_cursor* mdb_cursor_; |
115 | MDB_val mdb_key_, mdb_value_; |
116 | }; |
117 | |
118 | class LMDBReaderOp : public ReaderOpKernel { |
119 | public: |
120 | explicit LMDBReaderOp(OpKernelConstruction* context) |
121 | : ReaderOpKernel(context) { |
122 | Env* env = context->env(); |
123 | SetReaderFactory([this, env]() { return new LMDBReader(name(), env); }); |
124 | } |
125 | }; |
126 | |
127 | REGISTER_KERNEL_BUILDER(Name("LMDBReader" ).Device(DEVICE_CPU), LMDBReaderOp); |
128 | |
129 | } // namespace tensorflow |
130 | |