1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ |
18 | |
19 | #include <atomic> |
20 | |
21 | #include "tensorflow/core/framework/lookup_interface.h" |
22 | #include "tensorflow/core/platform/macros.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace lookup { |
26 | |
27 | // Base class for lookup tables that require initialization. |
28 | class InitializableLookupTable : public LookupInterface { |
29 | public: |
30 | class InitTableIterator; |
31 | class InitializerSerializer; |
32 | |
33 | // Performs batch lookups, for every element in the key tensor, Find returns |
34 | // the corresponding value into the values tensor. |
35 | // If an element is not present in the table, the given default value is used. |
36 | // |
37 | // For tables that require initialization, `Find` is available once the table |
38 | // is marked as initialized. |
39 | // |
40 | // Returns the following statuses: |
41 | // - OK: when the find finishes successfully. |
42 | // - FailedPrecondition: if the table is not initialized. |
43 | // - InvalidArgument: if any of the preconditions on the lookup key or value |
44 | // fails. |
45 | // - In addition, other implementations may provide another non-OK status |
46 | // specific to their failure modes. |
47 | Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, |
48 | const Tensor& default_value) final; |
49 | |
50 | // Returns errors::Unimplemented. |
51 | Status Insert(OpKernelContext* ctx, const Tensor& keys, |
52 | const Tensor& values) final { |
53 | return errors::Unimplemented( |
54 | "Insert not supported by InitializableLookupTable implementations" ); |
55 | } |
56 | |
57 | // Returns errors::Unimplemented. |
58 | Status Remove(OpKernelContext* ctx, const Tensor& keys) final { |
59 | return errors::Unimplemented( |
60 | "Remove not supported by InitializableLookupTable implementations" ); |
61 | } |
62 | |
63 | Status ExportValues(OpKernelContext* context) override { |
64 | return errors::Unimplemented( |
65 | "ExportValues not supported by InitializableLookupTable " |
66 | "implementations" ); |
67 | } |
68 | |
69 | Status ImportValues(OpKernelContext* ctx, const Tensor& keys, |
70 | const Tensor& values) final; |
71 | |
72 | TensorShape key_shape() const final { return TensorShape(); } |
73 | |
74 | TensorShape value_shape() const final { return TensorShape(); } |
75 | |
76 | // Returns whether the table was initialized and is ready to serve lookups. |
77 | bool is_initialized() const { |
78 | return is_initialized_.load(std::memory_order_acquire); |
79 | } |
80 | |
81 | // Initializes the table from the given init table iterator. |
82 | // |
83 | // Atomically, this operation prepares the table, populates it with the given |
84 | // iterator, and marks the table as initialized. |
85 | // |
86 | // Returns the following statuses: |
87 | // - OK: when the initialization was successful. |
88 | // - InvalidArgument: if any of the preconditions on the lookup key or value |
89 | // fails. |
90 | // - FailedPrecondition: if the table is already initialized and |
91 | // fail_if_initialized is set to true. |
92 | // - In addition, other implementations may provide another non-OK status |
93 | // specific to their failure modes. |
94 | Status Initialize(InitTableIterator& iter); |
95 | |
96 | // Initializes the table from the given init table iterator. `serializer` may |
97 | // specify how to serialize the table initializer, so that the table can be |
98 | // serialized using its metadata (as opposed to serializing a handle to the |
99 | // table). |
100 | Status Initialize(InitTableIterator& iter, |
101 | std::unique_ptr<InitializerSerializer> serializer); |
102 | |
103 | // Basic iterator to initialize lookup tables. |
104 | // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that |
105 | // the consumer may insert key-value pairs in batches. |
106 | // |
107 | // Then the iterator is exhausted, valid returns false and status returns |
108 | // Status::OutOfRange. |
109 | // |
110 | // This class is Thread-unsafe. |
111 | class InitTableIterator { |
112 | public: |
113 | InitTableIterator() {} |
114 | |
115 | virtual ~InitTableIterator() {} |
116 | |
117 | // Prepares the next batch of key and value tensors. |
118 | virtual void Next() = 0; |
119 | |
120 | // Returns true if keys and values point to valid tensors. |
121 | virtual bool Valid() const = 0; |
122 | |
123 | // Returns a tensor that contains the current batch of 'key' values. |
124 | virtual const Tensor& keys() const = 0; |
125 | |
126 | // Returns a tensor that contains the current batch of 'value' values. |
127 | virtual const Tensor& values() const = 0; |
128 | |
129 | // Returns an error if one has occurred, otherwise returns Status::OK. |
130 | virtual Status status() const = 0; |
131 | |
132 | // Returns the total number of elements that the iterator will produce. |
133 | // It might return -1 in case of error. |
134 | virtual int64_t total_size() const = 0; |
135 | |
136 | private: |
137 | TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator); |
138 | }; |
139 | |
140 | InitializableLookupTable* GetInitializableLookupTable() override { |
141 | return this; |
142 | } |
143 | |
144 | // Logic specifying how to represent an initializer as a GraphDef, so that a |
145 | // lookup table can be serialized using its metadata (as opposed to |
146 | // serializing the content of the table, or a handle to the table). |
147 | class InitializerSerializer { |
148 | public: |
149 | // A function which builds a graph so that executing `*out` will initialize |
150 | // `table`. |
151 | using SerializeFn = std::function<Status(GraphDefBuilder* builder, |
152 | Node* table, Node** out)>; |
153 | // A function which performs any necessary cleanup for the serializer. |
154 | using CleanupFn = std::function<void()>; |
155 | |
156 | // Wraps serialization logic that requires no cleanup. |
157 | explicit InitializerSerializer(SerializeFn serialize) |
158 | : serialize_(std::move(serialize)), cleanup_([] {}) {} |
159 | |
160 | // Wraps serialization logic along with a cleanup function. `cleanup` will |
161 | // be run when the serializer is destroyed. |
162 | explicit InitializerSerializer(SerializeFn serialize, CleanupFn cleanup) |
163 | : serialize_(std::move(serialize)), cleanup_(std::move(cleanup)) {} |
164 | |
165 | ~InitializerSerializer() { cleanup_(); } |
166 | |
167 | // Builds a graph so that executing `*out` will initialize `table`. |
168 | Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) { |
169 | return serialize_(builder, table, out); |
170 | } |
171 | |
172 | private: |
173 | SerializeFn serialize_; |
174 | CleanupFn cleanup_; |
175 | }; |
176 | |
177 | protected: |
178 | // Prepares and allocates the underlying data structure to store the given |
179 | // number of expected elements. |
180 | virtual Status DoPrepare(size_t expected_num_elements) = 0; |
181 | |
182 | // Same as DoPrepare() but derived implementations might choose to skip |
183 | // calling get_expected_num_elements if size is not needed for DoPrepare. |
184 | virtual Status DoLazyPrepare( |
185 | std::function<int64_t(void)> get_expected_num_elements) { |
186 | int64_t expected_num_elements = get_expected_num_elements(); |
187 | if (expected_num_elements < 0) { |
188 | return errors::FailedPrecondition("Got negative expected_num_elements." ); |
189 | } |
190 | return DoPrepare(expected_num_elements); |
191 | } |
192 | |
193 | // Populates the table in batches given keys and values as tensors into the |
194 | // underlying data structure. |
195 | virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0; |
196 | |
197 | // Performs the batch find operation on the underlying data structure. |
198 | virtual Status DoFind(const Tensor& keys, Tensor* values, |
199 | const Tensor& default_value) = 0; |
200 | |
201 | virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result); |
202 | |
203 | mutex mu_; |
204 | |
205 | protected: |
206 | // When set, provides a mechanism for serializing the table initializer as |
207 | // GraphDef. |
208 | std::unique_ptr<InitializerSerializer> initializer_serializer_; |
209 | |
210 | private: |
211 | std::atomic<bool> is_initialized_{false}; |
212 | }; |
213 | |
214 | // Iterator to initialize tables given 'keys' and 'values' tensors. |
215 | // |
216 | // The two tensors are returned in the first iteration. It doesn't loop |
217 | // over each element of the tensor since insertions in the lookup table can |
218 | // process batches. |
219 | class KeyValueTensorIterator |
220 | : public InitializableLookupTable::InitTableIterator { |
221 | public: |
222 | // keys and values are not owned by the iterator. |
223 | explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) |
224 | : keys_(keys), values_(values), valid_(true), status_(OkStatus()) { |
225 | TensorShape key_shape = keys_->shape(); |
226 | if (!key_shape.IsSameSize(values_->shape())) { |
227 | valid_ = false; |
228 | status_ = errors::InvalidArgument( |
229 | "keys and values should have the same dimension." , |
230 | key_shape.DebugString(), " vs " , values_->shape().DebugString()); |
231 | } |
232 | if (key_shape.num_elements() == 0) { |
233 | valid_ = false; |
234 | status_ = |
235 | errors::InvalidArgument("keys and values cannot be empty tensors." ); |
236 | } |
237 | } |
238 | |
239 | bool Valid() const override { return valid_; } |
240 | |
241 | void Next() override { |
242 | valid_ = false; |
243 | status_ = errors::OutOfRange("No more data." ); |
244 | } |
245 | |
246 | const Tensor& keys() const override { return *keys_; } |
247 | |
248 | const Tensor& values() const override { return *values_; } |
249 | |
250 | Status status() const override { return status_; } |
251 | |
252 | int64_t total_size() const override { |
253 | return keys_ == nullptr ? -1 : keys_->NumElements(); |
254 | } |
255 | |
256 | private: |
257 | TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator); |
258 | |
259 | const Tensor* keys_; // Doesn't own it. |
260 | const Tensor* values_; // Doesn't own it. |
261 | bool valid_; // true if the iterator points to an existing range. |
262 | Status status_; |
263 | }; |
264 | |
265 | } // namespace lookup |
266 | } // namespace tensorflow |
267 | |
268 | #endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ |
269 | |