1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
17#define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
18
19#include <atomic>
20
21#include "tensorflow/core/framework/lookup_interface.h"
22#include "tensorflow/core/platform/macros.h"
23
24namespace tensorflow {
25namespace lookup {
26
27// Base class for lookup tables that require initialization.
28class InitializableLookupTable : public LookupInterface {
29 public:
30 class InitTableIterator;
31 class InitializerSerializer;
32
33 // Performs batch lookups, for every element in the key tensor, Find returns
34 // the corresponding value into the values tensor.
35 // If an element is not present in the table, the given default value is used.
36 //
37 // For tables that require initialization, `Find` is available once the table
38 // is marked as initialized.
39 //
40 // Returns the following statuses:
41 // - OK: when the find finishes successfully.
42 // - FailedPrecondition: if the table is not initialized.
43 // - InvalidArgument: if any of the preconditions on the lookup key or value
44 // fails.
45 // - In addition, other implementations may provide another non-OK status
46 // specific to their failure modes.
47 Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
48 const Tensor& default_value) final;
49
50 // Returns errors::Unimplemented.
51 Status Insert(OpKernelContext* ctx, const Tensor& keys,
52 const Tensor& values) final {
53 return errors::Unimplemented(
54 "Insert not supported by InitializableLookupTable implementations");
55 }
56
57 // Returns errors::Unimplemented.
58 Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
59 return errors::Unimplemented(
60 "Remove not supported by InitializableLookupTable implementations");
61 }
62
63 Status ExportValues(OpKernelContext* context) override {
64 return errors::Unimplemented(
65 "ExportValues not supported by InitializableLookupTable "
66 "implementations");
67 }
68
69 Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
70 const Tensor& values) final;
71
72 TensorShape key_shape() const final { return TensorShape(); }
73
74 TensorShape value_shape() const final { return TensorShape(); }
75
76 // Returns whether the table was initialized and is ready to serve lookups.
77 bool is_initialized() const {
78 return is_initialized_.load(std::memory_order_acquire);
79 }
80
81 // Initializes the table from the given init table iterator.
82 //
83 // Atomically, this operation prepares the table, populates it with the given
84 // iterator, and marks the table as initialized.
85 //
86 // Returns the following statuses:
87 // - OK: when the initialization was successful.
88 // - InvalidArgument: if any of the preconditions on the lookup key or value
89 // fails.
90 // - FailedPrecondition: if the table is already initialized and
91 // fail_if_initialized is set to true.
92 // - In addition, other implementations may provide another non-OK status
93 // specific to their failure modes.
94 Status Initialize(InitTableIterator& iter);
95
96 // Initializes the table from the given init table iterator. `serializer` may
97 // specify how to serialize the table initializer, so that the table can be
98 // serialized using its metadata (as opposed to serializing a handle to the
99 // table).
100 Status Initialize(InitTableIterator& iter,
101 std::unique_ptr<InitializerSerializer> serializer);
102
103 // Basic iterator to initialize lookup tables.
104 // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
105 // the consumer may insert key-value pairs in batches.
106 //
107 // Then the iterator is exhausted, valid returns false and status returns
108 // Status::OutOfRange.
109 //
110 // This class is Thread-unsafe.
111 class InitTableIterator {
112 public:
113 InitTableIterator() {}
114
115 virtual ~InitTableIterator() {}
116
117 // Prepares the next batch of key and value tensors.
118 virtual void Next() = 0;
119
120 // Returns true if keys and values point to valid tensors.
121 virtual bool Valid() const = 0;
122
123 // Returns a tensor that contains the current batch of 'key' values.
124 virtual const Tensor& keys() const = 0;
125
126 // Returns a tensor that contains the current batch of 'value' values.
127 virtual const Tensor& values() const = 0;
128
129 // Returns an error if one has occurred, otherwise returns Status::OK.
130 virtual Status status() const = 0;
131
132 // Returns the total number of elements that the iterator will produce.
133 // It might return -1 in case of error.
134 virtual int64_t total_size() const = 0;
135
136 private:
137 TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
138 };
139
140 InitializableLookupTable* GetInitializableLookupTable() override {
141 return this;
142 }
143
144 // Logic specifying how to represent an initializer as a GraphDef, so that a
145 // lookup table can be serialized using its metadata (as opposed to
146 // serializing the content of the table, or a handle to the table).
147 class InitializerSerializer {
148 public:
149 // A function which builds a graph so that executing `*out` will initialize
150 // `table`.
151 using SerializeFn = std::function<Status(GraphDefBuilder* builder,
152 Node* table, Node** out)>;
153 // A function which performs any necessary cleanup for the serializer.
154 using CleanupFn = std::function<void()>;
155
156 // Wraps serialization logic that requires no cleanup.
157 explicit InitializerSerializer(SerializeFn serialize)
158 : serialize_(std::move(serialize)), cleanup_([] {}) {}
159
160 // Wraps serialization logic along with a cleanup function. `cleanup` will
161 // be run when the serializer is destroyed.
162 explicit InitializerSerializer(SerializeFn serialize, CleanupFn cleanup)
163 : serialize_(std::move(serialize)), cleanup_(std::move(cleanup)) {}
164
165 ~InitializerSerializer() { cleanup_(); }
166
167 // Builds a graph so that executing `*out` will initialize `table`.
168 Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) {
169 return serialize_(builder, table, out);
170 }
171
172 private:
173 SerializeFn serialize_;
174 CleanupFn cleanup_;
175 };
176
177 protected:
178 // Prepares and allocates the underlying data structure to store the given
179 // number of expected elements.
180 virtual Status DoPrepare(size_t expected_num_elements) = 0;
181
182 // Same as DoPrepare() but derived implementations might choose to skip
183 // calling get_expected_num_elements if size is not needed for DoPrepare.
184 virtual Status DoLazyPrepare(
185 std::function<int64_t(void)> get_expected_num_elements) {
186 int64_t expected_num_elements = get_expected_num_elements();
187 if (expected_num_elements < 0) {
188 return errors::FailedPrecondition("Got negative expected_num_elements.");
189 }
190 return DoPrepare(expected_num_elements);
191 }
192
193 // Populates the table in batches given keys and values as tensors into the
194 // underlying data structure.
195 virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
196
197 // Performs the batch find operation on the underlying data structure.
198 virtual Status DoFind(const Tensor& keys, Tensor* values,
199 const Tensor& default_value) = 0;
200
201 virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result);
202
203 mutex mu_;
204
205 protected:
206 // When set, provides a mechanism for serializing the table initializer as
207 // GraphDef.
208 std::unique_ptr<InitializerSerializer> initializer_serializer_;
209
210 private:
211 std::atomic<bool> is_initialized_{false};
212};
213
214// Iterator to initialize tables given 'keys' and 'values' tensors.
215//
216// The two tensors are returned in the first iteration. It doesn't loop
217// over each element of the tensor since insertions in the lookup table can
218// process batches.
219class KeyValueTensorIterator
220 : public InitializableLookupTable::InitTableIterator {
221 public:
222 // keys and values are not owned by the iterator.
223 explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
224 : keys_(keys), values_(values), valid_(true), status_(OkStatus()) {
225 TensorShape key_shape = keys_->shape();
226 if (!key_shape.IsSameSize(values_->shape())) {
227 valid_ = false;
228 status_ = errors::InvalidArgument(
229 "keys and values should have the same dimension.",
230 key_shape.DebugString(), " vs ", values_->shape().DebugString());
231 }
232 if (key_shape.num_elements() == 0) {
233 valid_ = false;
234 status_ =
235 errors::InvalidArgument("keys and values cannot be empty tensors.");
236 }
237 }
238
239 bool Valid() const override { return valid_; }
240
241 void Next() override {
242 valid_ = false;
243 status_ = errors::OutOfRange("No more data.");
244 }
245
246 const Tensor& keys() const override { return *keys_; }
247
248 const Tensor& values() const override { return *values_; }
249
250 Status status() const override { return status_; }
251
252 int64_t total_size() const override {
253 return keys_ == nullptr ? -1 : keys_->NumElements();
254 }
255
256 private:
257 TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
258
259 const Tensor* keys_; // Doesn't own it.
260 const Tensor* values_; // Doesn't own it.
261 bool valid_; // true if the iterator points to an existing range.
262 Status status_;
263};
264
265} // namespace lookup
266} // namespace tensorflow
267
268#endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
269