1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/reader_base.h"
17#include "tensorflow/core/framework/reader_op_kernel.h"
18#include "tensorflow/core/lib/core/errors.h"
19
20#include <sys/stat.h>
21#include "lmdb.h"
22
23namespace tensorflow {
24
// Aborts the process (via CHECK) when an LMDB call does not return
// MDB_SUCCESS, logging LMDB's text for the error code.
//
// `val` is evaluated exactly once and bound to a local: streaming `val` into
// the failure message would re-run the LMDB call (e.g. a second
// mdb_env_open()) and report the code of that second call instead.
#define MDB_CHECK(val)                                                 \
  do {                                                                 \
    const int mdb_check_rc = (val);                                    \
    CHECK_EQ(mdb_check_rc, MDB_SUCCESS) << mdb_strerror(mdb_check_rc); \
  } while (0)
26
27class LMDBReader : public ReaderBase {
28 public:
29 LMDBReader(const string& node_name, Env* /*unused*/)
30 : ReaderBase(strings::StrCat("LMDBReader '", node_name, "'")),
31 mdb_env_(nullptr),
32 mdb_dbi_(0),
33 mdb_txn_(nullptr),
34 mdb_cursor_(nullptr) {}
35
36 Status OnWorkStartedLocked() override {
37 MDB_CHECK(mdb_env_create(&mdb_env_));
38 int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
39
40 // Check if the LMDB filename is actually a file instead of a directory.
41 // If so, set appropriate flags so we can open it.
42 struct stat source_stat;
43 if (stat(current_work().c_str(), &source_stat) == 0 &&
44 (source_stat.st_mode & S_IFREG)) {
45 flags |= MDB_NOSUBDIR;
46 }
47
48 MDB_CHECK(mdb_env_open(mdb_env_, current_work().c_str(), flags, 0664));
49 MDB_CHECK(mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_));
50 MDB_CHECK(mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_));
51
52 return OkStatus();
53 }
54
55 Status OnWorkFinishedLocked() override {
56 if (mdb_env_ != nullptr) {
57 if (mdb_cursor_) {
58 mdb_cursor_close(mdb_cursor_);
59 mdb_cursor_ = nullptr;
60 }
61 mdb_dbi_close(mdb_env_, mdb_dbi_);
62 mdb_txn_abort(mdb_txn_);
63 mdb_env_close(mdb_env_);
64 mdb_txn_ = nullptr;
65 mdb_dbi_ = 0;
66 mdb_env_ = nullptr;
67 }
68 return OkStatus();
69 }
70
71 Status ReadLocked(tstring* key, tstring* value, bool* produced,
72 bool* at_end) override {
73 if (mdb_cursor_ == nullptr) {
74 MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_));
75 if (Seek(MDB_FIRST) == false) {
76 *at_end = true;
77 return OkStatus();
78 }
79 } else {
80 if (Seek(MDB_NEXT) == false) {
81 *at_end = true;
82 return OkStatus();
83 }
84 }
85 *key =
86 tstring(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
87 *value = tstring(static_cast<const char*>(mdb_value_.mv_data),
88 mdb_value_.mv_size);
89 *produced = true;
90 return OkStatus();
91 }
92
93 Status ResetLocked() override {
94 CHECK_EQ(Seek(MDB_FIRST), true);
95 return ReaderBase::ResetLocked();
96 }
97
98 private:
99 bool Seek(MDB_cursor_op op) {
100 CHECK_NOTNULL(mdb_cursor_);
101 int mdb_status = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, op);
102 if (mdb_status == MDB_NOTFOUND) {
103 return false;
104 } else {
105 MDB_CHECK(mdb_status);
106 return true;
107 }
108 }
109
110 MDB_env* mdb_env_;
111 MDB_dbi mdb_dbi_;
112
113 MDB_txn* mdb_txn_;
114 MDB_cursor* mdb_cursor_;
115 MDB_val mdb_key_, mdb_value_;
116};
117
118class LMDBReaderOp : public ReaderOpKernel {
119 public:
120 explicit LMDBReaderOp(OpKernelConstruction* context)
121 : ReaderOpKernel(context) {
122 Env* env = context->env();
123 SetReaderFactory([this, env]() { return new LMDBReader(name(), env); });
124 }
125};
126
127REGISTER_KERNEL_BUILDER(Name("LMDBReader").Device(DEVICE_CPU), LMDBReaderOp);
128
129} // namespace tensorflow
130