1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/lookup_table_op.h" |
17 | #define EIGEN_USE_THREADS |
18 | |
19 | #include <string> |
20 | #include <type_traits> |
21 | #include <utility> |
22 | |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/types.h" |
25 | #include "tensorflow/core/framework/variant.h" |
26 | #include "tensorflow/core/kernels/initializable_lookup_table.h" |
27 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
28 | #include "tensorflow/core/lib/hash/hash.h" |
29 | #include "tensorflow/core/platform/random.h" |
30 | |
31 | namespace tensorflow { |
32 | namespace lookup { |
33 | |
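// Returns a node name unique across exports: `base` plus a process-wide
// counter and a random suffix. Used by the AsGraphDef methods below.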
34 | std::string UniqueNodeName(const std::string& base) { |
35 | static std::atomic<int64_t> counter(0); |
36 | return strings::StrCat(base, "/" , counter.fetch_add(1), "/" , random::New64()); |
37 | } |
38 | |
// Lookup table that wraps an unordered_map, where the key and value data
// types are specified. Each individual value must be a scalar. If vector
// values are required, use MutableHashTableOfTensors.
42 | // |
43 | // This table is mutable and thread safe - Insert can be called at any time. |
44 | // |
45 | // Sample use case: |
46 | // |
47 | // MutableHashTableOfScalars<int64, int64> table; // int64 -> int64. |
48 | // // Populate the table, elements could be added in one or multiple calls. |
49 | // table.Insert(key_tensor, value_tensor); // Populate the table. |
50 | // |
51 | // table.Find(in_t, &out_t, default_t) |
52 | // |
53 | template <class K, class V> |
54 | class MutableHashTableOfScalars final : public LookupInterface { |
55 | public: |
56 | MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} |
57 | |
58 | size_t size() const override { |
59 | tf_shared_lock l(mu_); |
60 | return table_.size(); |
61 | } |
62 | |
63 | Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, |
64 | const Tensor& default_value) override { |
65 | const auto key_values = key.flat<K>(); |
66 | auto value_values = value->flat<V>(); |
67 | const auto default_flat = default_value.flat<V>(); |
68 | |
69 | int64_t total = value_values.size(); |
70 | int64_t default_total = default_flat.size(); |
71 | bool is_full_size_default = (total == default_total); |
72 | |
73 | tf_shared_lock l(mu_); |
74 | for (int64_t i = 0; i < key_values.size(); ++i) { |
      // If is_full_size_default is true, each key has an independent default
      // value: key_values(i) uses default_flat(i) as its default.
      //
      // If is_full_size_default is false, all keys share default_flat(0) as
      // their default value.
81 | value_values(i) = gtl::FindWithDefault( |
82 | table_, SubtleMustCopyIfIntegral(key_values(i)), |
83 | is_full_size_default ? default_flat(i) : default_flat(0)); |
84 | } |
85 | |
86 | return OkStatus(); |
87 | } |
88 | |
89 | Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { |
90 | const auto key_values = keys.flat<K>(); |
91 | const auto value_values = values.flat<V>(); |
92 | |
93 | mutex_lock l(mu_); |
94 | if (clear) { |
95 | table_.clear(); |
96 | } |
97 | for (int64_t i = 0; i < key_values.size(); ++i) { |
98 | gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), |
99 | SubtleMustCopyIfIntegral(value_values(i))); |
100 | } |
101 | return OkStatus(); |
102 | } |
103 | |
104 | Status Insert(OpKernelContext* ctx, const Tensor& keys, |
105 | const Tensor& values) override { |
106 | return DoInsert(false, keys, values); |
107 | } |
108 | |
109 | Status Remove(OpKernelContext* ctx, const Tensor& keys) override { |
110 | const auto key_values = keys.flat<K>(); |
111 | |
112 | mutex_lock l(mu_); |
113 | for (int64_t i = 0; i < key_values.size(); ++i) { |
114 | table_.erase(SubtleMustCopyIfIntegral(key_values(i))); |
115 | } |
116 | return OkStatus(); |
117 | } |
118 | |
119 | Status ImportValues(OpKernelContext* ctx, const Tensor& keys, |
120 | const Tensor& values) override { |
121 | return DoInsert(true, keys, values); |
122 | } |
123 | |
124 | Status ExportValues(OpKernelContext* ctx) override { |
125 | tf_shared_lock l(mu_); |
126 | int64_t size = table_.size(); |
127 | |
128 | Tensor* keys; |
129 | Tensor* values; |
130 | TF_RETURN_IF_ERROR( |
131 | ctx->allocate_output("keys" , TensorShape({size}), &keys)); |
132 | TF_RETURN_IF_ERROR( |
133 | ctx->allocate_output("values" , TensorShape({size}), &values)); |
134 | ExportKeysAndValues(keys, values); |
135 | return OkStatus(); |
136 | } |
137 | |
138 | DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } |
139 | |
140 | DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } |
141 | |
142 | TensorShape key_shape() const final { return TensorShape(); } |
143 | |
144 | TensorShape value_shape() const override { return TensorShape(); } |
145 | |
146 | int64_t MemoryUsed() const override { |
147 | int64_t ret = 0; |
148 | tf_shared_lock l(mu_); |
149 | for (unsigned i = 0; i < table_.bucket_count(); ++i) { |
150 | size_t bucket_size = table_.bucket_size(i); |
151 | if (bucket_size == 0) { |
152 | ret++; |
153 | } else { |
154 | ret += bucket_size; |
155 | } |
156 | } |
157 | return sizeof(MutableHashTableOfScalars) + ret; |
158 | } |
159 | |
160 | Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { |
161 | tf_shared_lock l(mu_); |
162 | int64_t size = table_.size(); |
163 | Tensor keys(key_dtype(), TensorShape({size})); |
164 | Tensor values(value_dtype(), TensorShape({size})); |
165 | ExportKeysAndValues(&keys, &values); |
166 | |
167 | // We set use_node_name_sharing with a unique node name so that the resource |
168 | // can outlive the MutableHashTableV2 kernel. This means that the lifetime |
169 | // of the resource will be tied to the lifetime of the resource manager it |
170 | // is created in. |
171 | // TODO(b/181695913): Provide a mechanism for deleting this resource |
172 | // earlier when appropriate. |
    Node* table = ops::SourceOp(
        "MutableHashTableV2",
        builder->opts()
            .WithName(UniqueNodeName("MutableHashTableFromGraphDef"))
            .WithAttr("use_node_name_sharing", true)
            .WithAttr("key_dtype", key_dtype())
            .WithAttr("value_dtype", value_dtype()));
    Node* keys_node = ops::SourceOp(
        "Const",
        builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
    Node* values_node =
        ops::SourceOp("Const", builder->opts()
                                   .WithAttr("dtype", value_dtype())
                                   .WithAttr("value", values));
    Node* import_table =
        ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
                       builder->opts()
                           .WithAttr("Tin", key_dtype())
                           .WithAttr("Tout", value_dtype()));
    *out = ops::UnaryOp("Identity", table,
                        builder->opts().WithControlInput(import_table));
194 | return OkStatus(); |
195 | } |
196 | |
197 | private: |
198 | // Writes all keys and values into `keys` and `values`. `keys` and `values` |
199 | // must point to tensors of size `table_.size()`. |
200 | void ExportKeysAndValues(Tensor* keys, Tensor* values) const |
201 | TF_SHARED_LOCKS_REQUIRED(mu_) { |
202 | auto keys_data = keys->flat<K>(); |
203 | auto values_data = values->flat<V>(); |
204 | int64_t i = 0; |
205 | for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { |
206 | keys_data(i) = it->first; |
207 | values_data(i) = it->second; |
208 | } |
209 | } |
210 | |
211 | mutable mutex mu_; |
212 | std::unordered_map<K, V> table_ TF_GUARDED_BY(mu_); |
213 | }; |
214 | |
// Lookup table that wraps an unordered_map. Behaves identically to
// MutableHashTableOfScalars except that each value must be a vector.
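//
// Sample use case (a sketch mirroring the scalar example above; tensor
// construction is omitted and the shapes are illustrative):
//
//   MutableHashTableOfTensors<int64, float> table;  // int64 -> float vector.
//   // value_tensor has shape [num_keys, dim], matching the value_shape attr.
//   table.Insert(key_tensor, value_tensor);
//   table.Find(in_t, &out_t, default_t);  // out_t: [num_keys, dim].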
217 | template <class K, class V> |
218 | class MutableHashTableOfTensors final : public LookupInterface { |
219 | public: |
220 | MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) { |
221 | OP_REQUIRES_OK(ctx, |
222 | GetNodeAttr(kernel->def(), "value_shape" , &value_shape_)); |
223 | OP_REQUIRES( |
224 | ctx, TensorShapeUtils::IsVector(value_shape_), |
225 | errors::InvalidArgument("Default value must be a vector, got shape " , |
226 | value_shape_.DebugString())); |
227 | } |
228 | |
229 | size_t size() const override { |
230 | tf_shared_lock l(mu_); |
231 | return table_.size(); |
232 | } |
233 | |
234 | Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, |
235 | const Tensor& default_value) override { |
236 | const auto default_flat = default_value.flat_inner_dims<V, 2>(); |
237 | const auto key_values = key.flat<K>(); |
238 | auto value_values = value->flat_inner_dims<V, 2>(); |
239 | int64_t value_dim = value_shape_.dim_size(0); |
240 | |
241 | int64_t total = value_values.size(); |
242 | int64_t default_total = default_flat.size(); |
243 | bool is_full_size_default = (total == default_total); |
244 | |
245 | tf_shared_lock l(mu_); |
246 | for (int64_t i = 0; i < key_values.size(); ++i) { |
247 | ValueArray* value_vec = |
248 | gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); |
249 | if (value_vec != nullptr) { |
250 | for (int64_t j = 0; j < value_dim; j++) { |
251 | value_values(i, j) = value_vec->at(j); |
252 | } |
253 | } else { |
        // If is_full_size_default is true, each key has an independent
        // default value: key_values(i) uses default_flat(i) as its default.
        //
        // If is_full_size_default is false, all keys share default_flat(0)
        // as their default value.
260 | for (int64_t j = 0; j < value_dim; j++) { |
261 | value_values(i, j) = |
262 | is_full_size_default ? default_flat(i, j) : default_flat(0, j); |
263 | } |
264 | } |
265 | } |
266 | |
267 | return OkStatus(); |
268 | } |
269 | |
270 | Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { |
271 | const auto key_values = keys.flat<K>(); |
272 | const auto value_values = values.flat_inner_dims<V, 2>(); |
273 | int64_t value_dim = value_shape_.dim_size(0); |
274 | |
275 | mutex_lock l(mu_); |
276 | if (clear) { |
277 | table_.clear(); |
278 | } |
279 | for (int64_t i = 0; i < key_values.size(); ++i) { |
280 | ValueArray value_vec; |
281 | for (int64_t j = 0; j < value_dim; j++) { |
282 | V value = value_values(i, j); |
283 | value_vec.push_back(value); |
284 | } |
285 | gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)), |
286 | value_vec); |
287 | } |
288 | return OkStatus(); |
289 | } |
290 | |
291 | Status Insert(OpKernelContext* ctx, const Tensor& keys, |
292 | const Tensor& values) override { |
293 | return DoInsert(false, keys, values); |
294 | } |
295 | |
296 | Status Remove(OpKernelContext* ctx, const Tensor& keys) override { |
297 | const auto key_values = keys.flat<K>(); |
298 | |
299 | mutex_lock l(mu_); |
300 | for (int64_t i = 0; i < key_values.size(); ++i) { |
301 | table_.erase(SubtleMustCopyIfIntegral(key_values(i))); |
302 | } |
303 | return OkStatus(); |
304 | } |
305 | |
306 | Status ImportValues(OpKernelContext* ctx, const Tensor& keys, |
307 | const Tensor& values) override { |
308 | return DoInsert(true, keys, values); |
309 | } |
310 | |
311 | Status ExportValues(OpKernelContext* ctx) override { |
312 | tf_shared_lock l(mu_); |
313 | int64_t size = table_.size(); |
314 | int64_t value_dim = value_shape_.dim_size(0); |
315 | |
316 | Tensor* keys; |
317 | Tensor* values; |
318 | TF_RETURN_IF_ERROR( |
319 | ctx->allocate_output("keys" , TensorShape({size}), &keys)); |
320 | TF_RETURN_IF_ERROR(ctx->allocate_output( |
321 | "values" , TensorShape({size, value_dim}), &values)); |
322 | ExportKeysAndValues(keys, values); |
323 | return OkStatus(); |
324 | } |
325 | |
326 | DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } |
327 | |
328 | DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } |
329 | |
330 | TensorShape key_shape() const final { return TensorShape(); } |
331 | |
332 | TensorShape value_shape() const override { return value_shape_; } |
333 | |
334 | int64_t MemoryUsed() const override { |
335 | int64_t ret = 0; |
336 | tf_shared_lock l(mu_); |
337 | for (unsigned i = 0; i < table_.bucket_count(); ++i) { |
338 | size_t bucket_size = table_.bucket_size(i); |
339 | if (bucket_size == 0) { |
340 | ret++; |
341 | } else { |
342 | ret += bucket_size; |
343 | } |
344 | } |
345 | return sizeof(MutableHashTableOfTensors) + ret; |
346 | } |
347 | |
348 | Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { |
349 | tf_shared_lock l(mu_); |
350 | int64_t size = table_.size(); |
351 | Tensor keys(key_dtype(), TensorShape({size})); |
352 | Tensor values(value_dtype(), TensorShape({size, value_shape_.dim_size(0)})); |
353 | ExportKeysAndValues(&keys, &values); |
354 | |
355 | // We set use_node_name_sharing with a unique node name so that the resource |
356 | // can outlive the MutableHashTableOfTensorsV2 kernel. This means that the |
357 | // lifetime of the resource will be tied to the lifetime of the resource |
358 | // manager it is created in. |
359 | // TODO(b/181695913): Provide a mechanism for deleting this resource |
360 | // earlier when appropriate. |
    Node* table =
        ops::SourceOp("MutableHashTableOfTensorsV2",
                      builder->opts()
                          .WithName(UniqueNodeName("MutableHashTableOfTensors"))
                          .WithAttr("use_node_name_sharing", true)
                          .WithAttr("key_dtype", key_dtype())
                          .WithAttr("value_dtype", value_dtype())
                          .WithAttr("value_shape", value_shape_));
    Node* keys_node = ops::SourceOp(
        "Const",
        builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
    Node* values_node =
        ops::SourceOp("Const", builder->opts()
                                   .WithAttr("dtype", value_dtype())
                                   .WithAttr("value", values));
    Node* import_table =
        ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
                       builder->opts()
                           .WithAttr("Tin", key_dtype())
                           .WithAttr("Tout", value_dtype()));
    *out = ops::UnaryOp("Identity", table,
                        builder->opts().WithControlInput(import_table));
383 | return OkStatus(); |
384 | } |
385 | |
386 | private: |
  // Writes all keys and values into `keys` and `values`. `keys` must point to
  // a tensor of size `table_.size()` and `values` to a matrix with
  // `table_.size()` rows and `value_shape_.dim_size(0)` columns.
389 | void ExportKeysAndValues(Tensor* keys, Tensor* values) const |
390 | TF_SHARED_LOCKS_REQUIRED(mu_) { |
391 | int64_t value_dim = value_shape_.dim_size(0); |
392 | auto keys_data = keys->flat<K>(); |
393 | auto values_data = values->matrix<V>(); |
394 | int64_t i = 0; |
395 | for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { |
396 | K key = it->first; |
397 | ValueArray value = it->second; |
398 | keys_data(i) = key; |
399 | for (int64_t j = 0; j < value_dim; j++) { |
400 | values_data(i, j) = value[j]; |
401 | } |
402 | } |
403 | } |
404 | |
405 | TensorShape value_shape_; |
406 | mutable mutex mu_; |
407 | typedef gtl::InlinedVector<V, 4> ValueArray; |
408 | std::unordered_map<K, ValueArray> table_ TF_GUARDED_BY(mu_); |
409 | }; |
410 | |
411 | namespace { |
412 | |
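// Integral keys hash to their own value; string keys use Hash64 below.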
413 | template <typename T> |
414 | inline uint64 HashScalar(const T& key) { |
415 | return static_cast<uint64>(key); |
416 | } |
417 | |
418 | inline uint64 HashScalar(const tstring& key) { return Hash64(key); } |
419 | |
// If the given shape is a scalar, return {1} instead. Otherwise leave it
// alone.
421 | TensorShape MaybeVectorizeShape(const TensorShape& shape) { |
422 | if (shape.dims() == 0) { |
423 | return TensorShape({1}); |
424 | } |
425 | return shape; |
426 | } |
427 | |
428 | } // namespace |
429 | |
430 | // Modeled after densehashtable in https://github.com/sparsehash/sparsehash |
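// Keys and values live in two flat bucket tensors: a bucket holding
// empty_key is free, and one holding deleted_key is a tombstone left by
// Remove. The bucket count is kept at a power of two so that
// `hash & (num_buckets_ - 1)` selects a bucket, and the table doubles
// (repeatedly if needed) once an insert would push the load factor above
// max_load_factor.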
431 | template <class K, class V> |
432 | class MutableDenseHashTable final : public LookupInterface { |
433 | public: |
434 | MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) { |
435 | OP_REQUIRES_OK( |
436 | ctx, GetNodeAttr(kernel->def(), "max_load_factor" , &max_load_factor_)); |
437 | OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1, |
438 | errors::InvalidArgument( |
439 | "max_load_factor must be between 0 and 1, got: " , |
440 | max_load_factor_)); |
441 | |
442 | OP_REQUIRES_OK(ctx, |
443 | GetNodeAttr(kernel->def(), "value_shape" , &value_shape_)); |
444 | OP_REQUIRES(ctx, |
445 | TensorShapeUtils::IsScalar(value_shape_) || |
446 | TensorShapeUtils::IsVector(value_shape_), |
447 | errors::InvalidArgument( |
448 | "Empty value must be a scalar or a vector, got shape " , |
449 | value_shape_.DebugString())); |
450 | |
451 | const Tensor* empty_key_input; |
452 | OP_REQUIRES_OK(ctx, ctx->input("empty_key" , &empty_key_input)); |
453 | key_shape_ = empty_key_input->shape(); |
454 | OP_REQUIRES(ctx, |
455 | TensorShapeUtils::IsScalar(key_shape_) || |
456 | TensorShapeUtils::IsVector(key_shape_), |
457 | errors::InvalidArgument( |
458 | "Empty key must be a scalar or a vector, got shape " , |
459 | key_shape_.DebugString())); |
460 | empty_key_ = *empty_key_input; |
461 | empty_key_hash_ = HashKey( |
462 | empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}), |
463 | 0); |
464 | |
465 | const Tensor* deleted_key_input; |
466 | OP_REQUIRES_OK(ctx, ctx->input("deleted_key" , &deleted_key_input)); |
467 | OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()), |
468 | errors::InvalidArgument( |
469 | "Empty and deleted keys must have same shape, got shapes: " , |
470 | key_shape_.DebugString(), " and " , |
471 | deleted_key_input->shape().DebugString())); |
472 | deleted_key_ = *deleted_key_input; |
473 | deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>( |
474 | {1, key_shape_.num_elements()}), |
475 | 0); |
476 | |
477 | if (empty_key_hash_ == deleted_key_hash_) { |
478 | const int64_t key_size = key_shape_.num_elements(); |
479 | const auto empty_key_matrix = |
480 | empty_key_.template shaped<K, 2>({1, key_size}); |
481 | const auto deleted_key_matrix = |
482 | deleted_key_.template shaped<K, 2>({1, key_size}); |
483 | OP_REQUIRES( |
484 | ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0), |
485 | errors::InvalidArgument("Empty and deleted keys cannot be equal" )); |
486 | } |
487 | |
488 | int64_t initial_num_buckets; |
489 | OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets" , |
490 | &initial_num_buckets)); |
491 | OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets)); |
492 | } |
493 | |
494 | size_t size() const override TF_LOCKS_EXCLUDED(mu_) { |
495 | tf_shared_lock l(mu_); |
496 | return num_entries_; |
497 | } |
498 | |
499 | Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, |
500 | const Tensor& default_value) override TF_LOCKS_EXCLUDED(mu_) { |
501 | const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); |
502 | const int64_t key_size = key_shape_.num_elements(); |
503 | const int64_t value_size = value_shape_.num_elements(); |
504 | if (key.NumElements() != num_elements * key_size) { |
505 | TensorShape expected_shape({num_elements}); |
506 | expected_shape.AppendShape(key_shape_); |
507 | return errors::InvalidArgument("Expected key shape " , |
508 | expected_shape.DebugString(), " got " , |
509 | key.shape().DebugString()); |
510 | } |
511 | const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); |
512 | auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); |
513 | const auto default_flat = default_value.flat<V>(); |
514 | |
515 | tf_shared_lock l(mu_); |
516 | const auto key_buckets_matrix = key_buckets_.template matrix<K>(); |
517 | const auto value_buckets_matrix = value_buckets_.template matrix<V>(); |
518 | const auto empty_key_matrix = |
519 | empty_key_.template shaped<K, 2>({1, key_size}); |
520 | const auto deleted_key_matrix = |
521 | deleted_key_.template shaped<K, 2>({1, key_size}); |
522 | const int64_t bit_mask = num_buckets_ - 1; |
523 | // TODO(andreasst): parallelize using work_sharder |
524 | for (int64_t i = 0; i < num_elements; ++i) { |
525 | const uint64 key_hash = HashKey(key_matrix, i); |
526 | if (empty_key_hash_ == key_hash && |
527 | IsEqualKey(empty_key_matrix, 0, key_matrix, i)) { |
528 | return errors::InvalidArgument( |
529 | "Using the empty_key as a table key is not allowed" ); |
530 | } |
531 | if (deleted_key_hash_ == key_hash && |
532 | IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) { |
533 | return errors::InvalidArgument( |
534 | "Using the deleted_key as a table key is not allowed" ); |
535 | } |
536 | int64_t bucket_index = key_hash & bit_mask; |
537 | int64_t num_probes = 0; |
538 | while (true) { |
539 | if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { |
540 | for (int64_t j = 0; j < value_size; ++j) { |
541 | // TODO(andreasst): check if we can get rid of SubtleMustCopy |
542 | // here and elsewhere in this file. |
543 | value_matrix(i, j) = |
544 | SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j)); |
545 | } |
546 | break; |
547 | } |
548 | if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) { |
549 | for (int64_t j = 0; j < value_size; ++j) { |
550 | value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j)); |
551 | } |
552 | break; |
553 | } |
        ++num_probes;
        // Quadratic probing over triangular offsets; with a power-of-two
        // bucket count this visits every bucket before num_probes reaches
        // num_buckets_.
        bucket_index = (bucket_index + num_probes) & bit_mask;
557 | if (num_probes >= num_buckets_) { |
558 | return errors::Internal( |
559 | "Internal error in MutableDenseHashTable lookup" ); |
560 | } |
561 | } |
562 | } |
563 | return OkStatus(); |
564 | } |
565 | |
566 | Status Insert(OpKernelContext* ctx, const Tensor& key, |
567 | const Tensor& value) override TF_LOCKS_EXCLUDED(mu_) { |
568 | const int64_t batch_size = (key.dims() == 0) ? 1 : key.dim_size(0); |
569 | if (key.NumElements() != batch_size * key_shape_.num_elements()) { |
570 | TensorShape expected_shape({batch_size}); |
571 | expected_shape.AppendShape(key_shape_); |
572 | return errors::InvalidArgument("Expected key shape " , |
573 | expected_shape.DebugString(), " got " , |
574 | key.shape().DebugString()); |
575 | } |
576 | mutex_lock l(mu_); |
577 | // For simplicity we assume that all keys in the input result in inserts |
578 | // rather than updates. That means we may grow the table even though we |
579 | // don't need to. As long as the number of keys inserted in one call is |
580 | // small compared to the size of the map, the impact of this is minimal. |
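    // For example (hypothetical numbers): with num_entries_ = 5,
    // batch_size = 3, num_buckets_ = 8, and max_load_factor_ = 0.75, the
    // pending 8 entries exceed 8 * 0.75 = 6, so the table doubles to 16
    // buckets (8 <= 16 * 0.75 = 12) before the insert proceeds.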
581 | const int64_t pending_num_entries = num_entries_ + batch_size; |
582 | if (pending_num_entries > num_buckets_ * max_load_factor_) { |
583 | int64_t new_num_buckets = num_buckets_; |
584 | do { |
585 | new_num_buckets <<= 1; |
586 | } while (pending_num_entries > new_num_buckets * max_load_factor_); |
587 | TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets)); |
588 | } |
589 | return DoInsert(ctx, key, value, false); |
590 | } |
591 | |
592 | Status Remove(OpKernelContext* ctx, const Tensor& key) override |
593 | TF_LOCKS_EXCLUDED(mu_) { |
594 | if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { |
595 | TensorShape expected_shape({key.dim_size(0)}); |
596 | expected_shape.AppendShape(key_shape_); |
597 | return errors::InvalidArgument("Expected key shape " , |
598 | expected_shape.DebugString(), " got " , |
599 | key.shape().DebugString()); |
600 | } |
601 | mutex_lock l(mu_); |
602 | return DoRemove(ctx, key); |
603 | } |
604 | |
605 | Status ImportValues(OpKernelContext* ctx, const Tensor& keys, |
606 | const Tensor& values) override TF_LOCKS_EXCLUDED(mu_) { |
607 | mutex_lock l(mu_); |
608 | num_buckets_ = keys.dim_size(0); |
609 | key_buckets_ = keys; |
610 | value_buckets_ = values; |
611 | // Count the number of keys that are not the empty_key or deleted_key. |
612 | // This requires iterating through the whole table but that is OK as we |
613 | // only execute it during checkpoint restore. |
614 | num_entries_ = 0; |
615 | const auto empty_key_tensor = |
616 | empty_key_.template shaped<K, 2>({1, key_shape_.num_elements()}); |
617 | const auto deleted_key_tensor = |
618 | deleted_key_.template shaped<K, 2>({1, key_shape_.num_elements()}); |
619 | const auto key_buckets_tensor = key_buckets_.template matrix<K>(); |
620 | for (int64_t i = 0; i < num_buckets_; ++i) { |
621 | if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) && |
622 | !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) { |
623 | ++num_entries_; |
624 | } |
625 | } |
626 | return OkStatus(); |
627 | } |
628 | |
629 | Status ExportValues(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) { |
630 | tf_shared_lock l(mu_); |
631 | TF_RETURN_IF_ERROR(ctx->set_output("keys" , key_buckets_)); |
632 | TF_RETURN_IF_ERROR(ctx->set_output("values" , value_buckets_)); |
633 | return OkStatus(); |
634 | } |
635 | |
636 | Status CheckKeyAndValueTensorsForImport(const Tensor& keys, |
637 | const Tensor& values) override { |
638 | TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); |
639 | TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); |
640 | |
641 | // The storage format in key_buckets_ and value_buckets_ is always vectors, |
642 | // even if the inputs are scalars. This is what eventually gets exported |
643 | // and is expected by the import method as well. |
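    // For example (illustrative shapes): with scalar keys and scalar values,
    // the buckets are stored as {num_buckets_, 1} matrices, so import expects
    // keys of shape {n, 1} and values of shape {n, 1} rather than flat
    // vectors of length n.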
644 | TensorShape key_shape = MaybeVectorizeShape(key_shape_); |
645 | TensorShape value_shape = MaybeVectorizeShape(value_shape_); |
646 | |
647 | // Compute the final expected shape of the value by starting with the shape |
648 | // of all keys, removing the dimensions particular to each key and then |
649 | // appending the shape of a single value. |
650 | TensorShape expected_value_shape = keys.shape(); |
651 | expected_value_shape.RemoveLastDims(key_shape.dims()); |
652 | expected_value_shape.AppendShape(value_shape); |
653 | if (values.shape() != expected_value_shape) { |
      return errors::InvalidArgument(
          "Expected shape ", expected_value_shape.DebugString(),
          " for value, got ", values.shape().DebugString());
657 | } |
658 | return OkStatus(); |
659 | } |
660 | |
661 | DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } |
662 | |
663 | DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } |
664 | |
665 | TensorShape key_shape() const override { return key_shape_; } |
666 | |
667 | TensorShape value_shape() const override { return value_shape_; } |
668 | |
669 | int64_t MemoryUsed() const override TF_LOCKS_EXCLUDED(mu_) { |
670 | tf_shared_lock l(mu_); |
671 | return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + |
672 | value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); |
673 | } |
674 | |
675 | private: |
676 | Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, |
677 | bool ignore_empty_and_deleted_key) |
678 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
679 | const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); |
680 | const int64_t value_size = value_shape_.num_elements(); |
681 | const int64_t key_size = key_shape_.num_elements(); |
682 | const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); |
683 | auto value_matrix = value.shaped<V, 2>({num_elements, value_size}); |
684 | |
685 | auto key_buckets_matrix = key_buckets_.template matrix<K>(); |
686 | auto value_buckets_matrix = value_buckets_.template matrix<V>(); |
687 | const auto empty_key_tensor = |
688 | empty_key_.template shaped<K, 2>({1, key_size}); |
689 | const auto deleted_key_tensor = |
690 | deleted_key_.template shaped<K, 2>({1, key_size}); |
691 | const int64_t bit_mask = num_buckets_ - 1; |
692 | for (int64_t i = 0; i < num_elements; ++i) { |
693 | const uint64 key_hash = HashKey(key_matrix, i); |
694 | if (empty_key_hash_ == key_hash && |
695 | IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { |
696 | if (ignore_empty_and_deleted_key) { |
697 | continue; |
698 | } |
699 | return errors::InvalidArgument( |
700 | "Using the empty_key as a table key is not allowed" ); |
701 | } |
702 | if (deleted_key_hash_ == key_hash && |
703 | IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { |
704 | if (ignore_empty_and_deleted_key) { |
705 | continue; |
706 | } |
707 | return errors::InvalidArgument( |
708 | "Using the deleted_key as a table key is not allowed" ); |
709 | } |
710 | int64_t bucket_index = key_hash & bit_mask; |
711 | int64_t num_probes = 0; |
712 | while (true) { |
713 | if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { |
714 | for (int64_t j = 0; j < value_size; ++j) { |
715 | value_buckets_matrix(bucket_index, j) = |
716 | SubtleMustCopyIfIntegral(value_matrix(i, j)); |
717 | } |
718 | break; |
719 | } |
720 | if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) || |
721 | IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor, |
722 | 0)) { |
723 | ++num_entries_; |
724 | for (int64_t j = 0; j < key_size; ++j) { |
725 | key_buckets_matrix(bucket_index, j) = |
726 | SubtleMustCopyIfIntegral(key_matrix(i, j)); |
727 | } |
728 | for (int64_t j = 0; j < value_size; ++j) { |
729 | value_buckets_matrix(bucket_index, j) = |
730 | SubtleMustCopyIfIntegral(value_matrix(i, j)); |
731 | } |
732 | break; |
733 | } |
734 | ++num_probes; |
735 | bucket_index = |
736 | (bucket_index + num_probes) & bit_mask; // quadratic probing |
737 | if (num_probes >= num_buckets_) { |
738 | return errors::Internal( |
739 | "Internal error in MutableDenseHashTable insert" ); |
740 | } |
741 | } |
742 | } |
743 | return OkStatus(); |
744 | } |
745 | |
746 | Status DoRemove(OpKernelContext* ctx, const Tensor& key) |
747 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
748 | const int64_t num_elements = key.dim_size(0); |
749 | const int64_t key_size = key_shape_.num_elements(); |
750 | const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); |
751 | |
752 | auto key_buckets_matrix = key_buckets_.template matrix<K>(); |
753 | const auto empty_key_tensor = |
754 | empty_key_.template shaped<K, 2>({1, key_size}); |
755 | const auto deleted_key_tensor = |
756 | deleted_key_.template shaped<K, 2>({1, key_size}); |
757 | const auto deleted_key_flat = deleted_key_.template flat<K>(); |
758 | const int64_t bit_mask = num_buckets_ - 1; |
759 | for (int64_t i = 0; i < num_elements; ++i) { |
760 | const uint64 key_hash = HashKey(key_matrix, i); |
761 | if (empty_key_hash_ == key_hash && |
762 | IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { |
763 | return errors::InvalidArgument( |
764 | "Using the empty_key as a table key is not allowed" ); |
765 | } |
766 | if (deleted_key_hash_ == key_hash && |
767 | IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) { |
768 | return errors::InvalidArgument( |
769 | "Using the deleted_key as a table key is not allowed" ); |
770 | } |
771 | int64_t bucket_index = key_hash & bit_mask; |
772 | int64_t num_probes = 0; |
773 | while (true) { |
774 | if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { |
775 | --num_entries_; |
776 | for (int64_t j = 0; j < key_size; ++j) { |
777 | key_buckets_matrix(bucket_index, j) = |
778 | SubtleMustCopyIfIntegral(deleted_key_flat(j)); |
779 | } |
780 | break; |
781 | } |
782 | if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { |
783 | break; |
784 | } |
785 | ++num_probes; |
786 | bucket_index = |
787 | (bucket_index + num_probes) & bit_mask; // quadratic probing |
788 | if (num_probes >= num_buckets_) { |
789 | return errors::Internal( |
790 | "Internal error in MutableDenseHashTable remove" ); |
791 | } |
792 | } |
793 | } |
794 | return OkStatus(); |
795 | } |
796 | |
797 | Status AllocateBuckets(OpKernelContext* ctx, int64_t new_num_buckets) |
798 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
799 | if (new_num_buckets < 4 || |
800 | ((new_num_buckets & (new_num_buckets - 1)) != 0)) { |
801 | return errors::InvalidArgument( |
802 | "Number of buckets must be at least 4 and a power of 2, got: " , |
803 | new_num_buckets); |
804 | } |
805 | num_buckets_ = new_num_buckets; |
806 | num_entries_ = 0; |
807 | |
808 | const int64_t key_size = key_shape_.num_elements(); |
809 | TF_RETURN_IF_ERROR(ctx->allocate_temp( |
810 | key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_)); |
811 | auto key_buckets_matrix = key_buckets_.matrix<K>(); |
812 | const auto empty_key_flat = empty_key_.template flat<K>(); |
813 | for (int64_t i = 0; i < num_buckets_; ++i) { |
814 | for (int64_t j = 0; j < key_size; ++j) { |
815 | key_buckets_matrix(i, j) = empty_key_flat(j); |
816 | } |
817 | } |
818 | |
819 | const int64_t value_size = value_shape_.num_elements(); |
820 | |
821 | TF_RETURN_IF_ERROR(ctx->allocate_temp( |
822 | value_dtype(), TensorShape({num_buckets_, value_size}), |
823 | &value_buckets_)); |
824 | auto value_buckets_matrix = value_buckets_.matrix<V>(); |
825 | for (int64_t i = 0; i < num_buckets_; ++i) { |
826 | for (int64_t j = 0; j < value_size; ++j) { |
827 | // Initialize values to the default value for the type to avoid |
828 | // exposing uninitialized memory in ExportValues(). |
829 | value_buckets_matrix(i, j) = V(); |
830 | } |
831 | } |
832 | return OkStatus(); |
833 | } |
834 | |
835 | Status Rebucket(OpKernelContext* ctx, int64_t num_new_buckets) |
836 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
837 | Tensor old_key_buckets = key_buckets_; |
838 | Tensor old_value_buckets = value_buckets_; |
839 | TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets)); |
840 | return DoInsert(ctx, old_key_buckets, old_value_buckets, true); |
841 | } |
842 | |
843 | uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64_t index) const { |
844 | if (key_shape_.num_elements() == 1) { |
845 | return HashScalar(key(index, 0)); |
846 | } |
847 | uint64 result = 0; |
848 | for (int64_t i = 0; i < key_shape_.num_elements(); ++i) { |
849 | result = Hash64Combine(result, HashScalar(key(index, i))); |
850 | } |
851 | return result; |
852 | } |
853 | |
854 | // Use a template to allow this function to be used both with Matrix and |
855 | // ConstMatrix types. |
856 | template <typename MT2> |
857 | bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64_t index1, |
858 | MT2 tensor2, int64_t index2) const { |
859 | for (int64_t i = 0; i < key_shape_.num_elements(); ++i) { |
860 | if (tensor1(index1, i) != tensor2(index2, i)) { |
861 | return false; |
862 | } |
863 | } |
864 | return true; |
865 | } |
866 | |
867 | TensorShape key_shape_; |
868 | TensorShape value_shape_; |
869 | float max_load_factor_; |
870 | mutable mutex mu_; |
871 | int64_t num_entries_ TF_GUARDED_BY(mu_); |
872 | int64_t num_buckets_ TF_GUARDED_BY(mu_); |
873 | Tensor key_buckets_ TF_GUARDED_BY(mu_); |
874 | Tensor value_buckets_ TF_GUARDED_BY(mu_); |
875 | Tensor empty_key_; |
876 | uint64 empty_key_hash_; |
877 | Tensor deleted_key_; |
878 | uint64 deleted_key_hash_; |
879 | }; |
880 | |
881 | } // namespace lookup |
882 | |
883 | // Base class for kernels that take a LookupTable handle as the 0th input. |
884 | class LookupTableOpKernel : public OpKernel { |
885 | public: |
886 | explicit LookupTableOpKernel(OpKernelConstruction* ctx) |
887 | : OpKernel(ctx), |
888 | expected_input_0_(ctx->input_type(0) == DT_RESOURCE ? DT_RESOURCE |
889 | : DT_STRING_REF) {} |
890 | |
891 | protected: |
892 | Status GetTable(OpKernelContext* ctx, lookup::LookupInterface** table) { |
893 | if (expected_input_0_ == DT_RESOURCE) { |
894 | return GetResourceLookupTable("table_handle" , ctx, table); |
895 | } else { |
896 | return GetReferenceLookupTable("table_handle" , ctx, table); |
897 | } |
898 | } |
899 | |
900 | // Input 0 could be a STRING_REF or a RESOURCE |
901 | const DataType expected_input_0_; |
902 | }; |
903 | |
904 | // Table lookup op. Perform the lookup operation on the given table. |
905 | class LookupTableFindOp : public LookupTableOpKernel { |
906 | public: |
907 | using LookupTableOpKernel::LookupTableOpKernel; |
908 | |
909 | void Compute(OpKernelContext* ctx) override { |
910 | lookup::LookupInterface* table; |
911 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
912 | core::ScopedUnref unref_me(table); |
913 | |
914 | DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(), |
915 | table->value_dtype()}; |
916 | DataTypeVector expected_outputs = {table->value_dtype()}; |
917 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); |
918 | |
919 | const Tensor& key = ctx->input(1); |
920 | const Tensor& default_value = ctx->input(2); |
921 | OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value)); |
922 | |
923 | TensorShape output_shape = key.shape(); |
924 | output_shape.RemoveLastDims(table->key_shape().dims()); |
925 | output_shape.AppendShape(table->value_shape()); |
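    // For example (illustrative shapes): with scalar table keys, a value
    // shape of [2], and a key batch of shape [4], the output shape is [4, 2].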
926 | Tensor* out; |
927 | OP_REQUIRES_OK(ctx, ctx->allocate_output("values" , output_shape, &out)); |
928 | |
929 | OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); |
930 | } |
931 | }; |
932 | |
933 | REGISTER_KERNEL_BUILDER(Name("LookupTableFind" ).Device(DEVICE_CPU), |
934 | LookupTableFindOp); |
935 | REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2" ).Device(DEVICE_CPU), |
936 | LookupTableFindOp); |
937 | |
938 | // Table insert op. |
939 | class LookupTableInsertOp : public LookupTableOpKernel { |
940 | public: |
941 | using LookupTableOpKernel::LookupTableOpKernel; |
942 | |
943 | void Compute(OpKernelContext* ctx) override { |
944 | lookup::LookupInterface* table; |
945 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
946 | core::ScopedUnref unref_me(table); |
947 | |
948 | DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(), |
949 | table->value_dtype()}; |
950 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); |
951 | |
952 | const Tensor& keys = ctx->input(1); |
953 | const Tensor& values = ctx->input(2); |
954 | OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values)); |
955 | |
956 | int64_t memory_used_before = 0; |
957 | if (ctx->track_allocations()) { |
958 | memory_used_before = table->MemoryUsed(); |
959 | } |
960 | OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values)); |
961 | if (ctx->track_allocations()) { |
962 | ctx->record_persistent_memory_allocation(table->MemoryUsed() - |
963 | memory_used_before); |
964 | } |
965 | } |
966 | }; |
967 | |
968 | REGISTER_KERNEL_BUILDER(Name("LookupTableInsert" ).Device(DEVICE_CPU), |
969 | LookupTableInsertOp); |
970 | REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2" ).Device(DEVICE_CPU), |
971 | LookupTableInsertOp); |
972 | |
973 | // Table remove op. |
974 | class LookupTableRemoveOp : public LookupTableOpKernel { |
975 | public: |
976 | using LookupTableOpKernel::LookupTableOpKernel; |
977 | |
978 | void Compute(OpKernelContext* ctx) override { |
979 | lookup::LookupInterface* table; |
980 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
981 | core::ScopedUnref unref_me(table); |
982 | |
983 | DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype()}; |
984 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); |
985 | |
986 | const Tensor& key = ctx->input(1); |
987 | OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key)); |
988 | |
989 | int64_t memory_used_before = 0; |
990 | if (ctx->track_allocations()) { |
991 | memory_used_before = table->MemoryUsed(); |
992 | } |
993 | OP_REQUIRES_OK(ctx, table->Remove(ctx, key)); |
994 | if (ctx->track_allocations()) { |
995 | ctx->record_persistent_memory_allocation(table->MemoryUsed() - |
996 | memory_used_before); |
997 | } |
998 | } |
999 | }; |
1000 | |
1001 | REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2" ).Device(DEVICE_CPU), |
1002 | LookupTableRemoveOp); |
1003 | |
1004 | // Op that returns the size of the given table. |
1005 | class LookupTableSizeOp : public LookupTableOpKernel { |
1006 | public: |
1007 | using LookupTableOpKernel::LookupTableOpKernel; |
1008 | |
1009 | void Compute(OpKernelContext* ctx) override { |
1010 | lookup::LookupInterface* table; |
1011 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
1012 | core::ScopedUnref unref_me(table); |
1013 | |
1014 | Tensor* out; |
1015 | OP_REQUIRES_OK(ctx, ctx->allocate_output("size" , TensorShape({}), &out)); |
1016 | out->flat<int64_t>().setConstant(table->size()); |
1017 | } |
1018 | }; |
1019 | |
1020 | REGISTER_KERNEL_BUILDER(Name("LookupTableSize" ).Device(DEVICE_CPU), |
1021 | LookupTableSizeOp); |
1022 | REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2" ).Device(DEVICE_CPU), |
1023 | LookupTableSizeOp); |
1024 | |
1025 | // Op that outputs tensors of all keys and all values. |
1026 | class LookupTableExportOp : public LookupTableOpKernel { |
1027 | public: |
1028 | using LookupTableOpKernel::LookupTableOpKernel; |
1029 | |
1030 | void Compute(OpKernelContext* ctx) override { |
1031 | lookup::LookupInterface* table; |
1032 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
1033 | core::ScopedUnref unref_me(table); |
1034 | |
1035 | OP_REQUIRES_OK(ctx, table->ExportValues(ctx)); |
1036 | } |
1037 | }; |
1038 | |
1039 | REGISTER_KERNEL_BUILDER(Name("LookupTableExport" ).Device(DEVICE_CPU), |
1040 | LookupTableExportOp); |
1041 | REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2" ).Device(DEVICE_CPU), |
1042 | LookupTableExportOp); |
1043 | |
1044 | // Clear the table and insert data. |
1045 | class LookupTableImportOp : public LookupTableOpKernel { |
1046 | public: |
1047 | using LookupTableOpKernel::LookupTableOpKernel; |
1048 | |
1049 | void Compute(OpKernelContext* ctx) override { |
1050 | lookup::LookupInterface* table; |
1051 | OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); |
1052 | core::ScopedUnref unref_me(table); |
1053 | |
1054 | DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(), |
1055 | table->value_dtype()}; |
1056 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); |
1057 | |
1058 | const Tensor& keys = ctx->input(1); |
1059 | const Tensor& values = ctx->input(2); |
1060 | OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values)); |
1061 | |
    int64_t memory_used_before = 0;
1063 | if (ctx->track_allocations()) { |
1064 | memory_used_before = table->MemoryUsed(); |
1065 | } |
1066 | OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); |
1067 | if (ctx->track_allocations()) { |
1068 | ctx->record_persistent_memory_allocation(table->MemoryUsed() - |
1069 | memory_used_before); |
1070 | } |
1071 | } |
1072 | }; |
1073 | |
1074 | REGISTER_KERNEL_BUILDER(Name("LookupTableImport" ).Device(DEVICE_CPU), |
1075 | LookupTableImportOp); |
1076 | REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2" ).Device(DEVICE_CPU), |
1077 | LookupTableImportOp); |
1078 | |
1079 | // Register the HashTable op with the currently supported key and value types. |
1080 | #define REGISTER_KERNEL(key_dtype, value_dtype) \ |
1081 | REGISTER_KERNEL_BUILDER( \ |
1082 | Name("HashTable") \ |
1083 | .Device(DEVICE_CPU) \ |
1084 | .TypeConstraint<key_dtype>("key_dtype") \ |
1085 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1086 | LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ |
1087 | value_dtype>) \ |
1088 | REGISTER_KERNEL_BUILDER( \ |
1089 | Name("HashTableV2") \ |
1090 | .Device(DEVICE_CPU) \ |
1091 | .TypeConstraint<key_dtype>("key_dtype") \ |
1092 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1093 | LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ |
1094 | value_dtype>) \ |
1095 | REGISTER_KERNEL_BUILDER( \ |
1096 | Name("AnonymousHashTable") \ |
1097 | .Device(DEVICE_CPU) \ |
1098 | .TypeConstraint<key_dtype>("key_dtype") \ |
1099 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1100 | AnonymousLookupTableOp<lookup::HashTable<key_dtype, value_dtype>, \ |
1101 | key_dtype, value_dtype>) |
1102 | |
1103 | REGISTER_KERNEL(int32, double); |
1104 | REGISTER_KERNEL(int32, float); |
1105 | REGISTER_KERNEL(int32, int32); |
1106 | REGISTER_KERNEL(int32, tstring); |
1107 | REGISTER_KERNEL(int64_t, double); |
1108 | REGISTER_KERNEL(int64_t, float); |
1109 | REGISTER_KERNEL(int64_t, int32); |
1110 | REGISTER_KERNEL(int64_t, int64_t); |
1111 | REGISTER_KERNEL(int64_t, tstring); |
1112 | REGISTER_KERNEL(tstring, bool); |
1113 | REGISTER_KERNEL(tstring, double); |
1114 | REGISTER_KERNEL(tstring, float); |
1115 | REGISTER_KERNEL(tstring, int32); |
1116 | REGISTER_KERNEL(tstring, int64_t); |
1117 | REGISTER_KERNEL(tstring, tstring); |
1118 | |
1119 | #undef REGISTER_KERNEL |
1120 | |
1121 | // Register the MutableHashTable op. |
1122 | #define REGISTER_KERNEL(key_dtype, value_dtype) \ |
1123 | REGISTER_KERNEL_BUILDER( \ |
1124 | Name("MutableHashTable") \ |
1125 | .Device(DEVICE_CPU) \ |
1126 | .TypeConstraint<key_dtype>("key_dtype") \ |
1127 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1128 | LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ |
1129 | key_dtype, value_dtype>) \ |
1130 | REGISTER_KERNEL_BUILDER( \ |
1131 | Name("MutableHashTableV2") \ |
1132 | .Device(DEVICE_CPU) \ |
1133 | .TypeConstraint<key_dtype>("key_dtype") \ |
1134 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1135 | LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ |
1136 | key_dtype, value_dtype>) \ |
1137 | REGISTER_KERNEL_BUILDER( \ |
1138 | Name("AnonymousMutableHashTable") \ |
1139 | .Device(DEVICE_CPU) \ |
1140 | .TypeConstraint<key_dtype>("key_dtype") \ |
1141 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1142 | AnonymousLookupTableOp< \ |
1143 | lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ |
1144 | key_dtype, value_dtype>) |
1145 | |
1146 | REGISTER_KERNEL(int32, double); |
1147 | REGISTER_KERNEL(int32, float); |
1148 | REGISTER_KERNEL(int32, int32); |
1149 | REGISTER_KERNEL(int64_t, double); |
1150 | REGISTER_KERNEL(int64_t, float); |
1151 | REGISTER_KERNEL(int64_t, int32); |
1152 | REGISTER_KERNEL(int64_t, int64_t); |
1153 | REGISTER_KERNEL(int64_t, tstring); |
1154 | REGISTER_KERNEL(int64_t, Variant); |
1155 | REGISTER_KERNEL(tstring, bool); |
1156 | REGISTER_KERNEL(tstring, double); |
1157 | REGISTER_KERNEL(tstring, float); |
1158 | REGISTER_KERNEL(tstring, int32); |
1159 | REGISTER_KERNEL(tstring, int64_t); |
1160 | |
1161 | #undef REGISTER_KERNEL |
1162 | |
1163 | // Register the MutableHashTableOfTensors op. |
1164 | #define REGISTER_KERNEL(key_dtype, value_dtype) \ |
1165 | REGISTER_KERNEL_BUILDER( \ |
1166 | Name("MutableHashTableOfTensors") \ |
1167 | .Device(DEVICE_CPU) \ |
1168 | .TypeConstraint<key_dtype>("key_dtype") \ |
1169 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1170 | LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ |
1171 | key_dtype, value_dtype>) \ |
1172 | REGISTER_KERNEL_BUILDER( \ |
1173 | Name("MutableHashTableOfTensorsV2") \ |
1174 | .Device(DEVICE_CPU) \ |
1175 | .TypeConstraint<key_dtype>("key_dtype") \ |
1176 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1177 | LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ |
1178 | key_dtype, value_dtype>) \ |
1179 | REGISTER_KERNEL_BUILDER( \ |
1180 | Name("AnonymousMutableHashTableOfTensors") \ |
1181 | .Device(DEVICE_CPU) \ |
1182 | .TypeConstraint<key_dtype>("key_dtype") \ |
1183 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1184 | AnonymousLookupTableOp< \ |
1185 | lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ |
1186 | key_dtype, value_dtype>) |
1187 | |
1188 | REGISTER_KERNEL(int32, double); |
1189 | REGISTER_KERNEL(int32, float); |
1190 | REGISTER_KERNEL(int32, int32); |
1191 | REGISTER_KERNEL(int64_t, double); |
1192 | REGISTER_KERNEL(int64_t, float); |
1193 | REGISTER_KERNEL(int64_t, int32); |
1194 | REGISTER_KERNEL(int64_t, int64_t); |
1195 | REGISTER_KERNEL(int64_t, tstring); |
1196 | REGISTER_KERNEL(tstring, bool); |
1197 | REGISTER_KERNEL(tstring, double); |
1198 | REGISTER_KERNEL(tstring, float); |
1199 | REGISTER_KERNEL(tstring, int32); |
1200 | REGISTER_KERNEL(tstring, int64_t); |
1201 | |
1202 | #undef REGISTER_KERNEL |
1203 | |
1204 | // Register the MutableDenseHashTable op. |
1205 | #define REGISTER_KERNEL(key_dtype, value_dtype) \ |
1206 | REGISTER_KERNEL_BUILDER( \ |
1207 | Name("MutableDenseHashTable") \ |
1208 | .Device(DEVICE_CPU) \ |
1209 | .TypeConstraint<key_dtype>("key_dtype") \ |
1210 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1211 | LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ |
1212 | key_dtype, value_dtype>) \ |
1213 | REGISTER_KERNEL_BUILDER( \ |
1214 | Name("MutableDenseHashTableV2") \ |
1215 | .Device(DEVICE_CPU) \ |
1216 | .TypeConstraint<key_dtype>("key_dtype") \ |
1217 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1218 | LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ |
1219 | key_dtype, value_dtype>) \ |
1220 | REGISTER_KERNEL_BUILDER( \ |
1221 | Name("AnonymousMutableDenseHashTable") \ |
1222 | .Device(DEVICE_CPU) \ |
1223 | .TypeConstraint<key_dtype>("key_dtype") \ |
1224 | .TypeConstraint<value_dtype>("value_dtype"), \ |
1225 | AnonymousLookupTableOp< \ |
1226 | lookup::MutableDenseHashTable<key_dtype, value_dtype>, key_dtype, \ |
1227 | value_dtype>) |
1228 | |
1229 | REGISTER_KERNEL(int32, double); |
1230 | REGISTER_KERNEL(int32, float); |
1231 | REGISTER_KERNEL(int32, int32); |
1232 | REGISTER_KERNEL(int64_t, bool); |
1233 | REGISTER_KERNEL(int64_t, double); |
1234 | REGISTER_KERNEL(int64_t, float); |
1235 | REGISTER_KERNEL(int64_t, int32); |
1236 | REGISTER_KERNEL(int64_t, int64_t); |
1237 | REGISTER_KERNEL(int64_t, Variant); |
1238 | REGISTER_KERNEL(tstring, bool); |
1239 | REGISTER_KERNEL(tstring, double); |
1240 | REGISTER_KERNEL(tstring, float); |
1241 | REGISTER_KERNEL(tstring, int32); |
1242 | REGISTER_KERNEL(tstring, int64_t); |
1243 | REGISTER_KERNEL(tstring, ResourceHandle); |
1244 | |
1245 | #undef REGISTER_KERNEL |
1246 | |
1247 | } // namespace tensorflow |
1248 | |