1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/lookup_table_op.h"
17#define EIGEN_USE_THREADS
18
19#include <string>
20#include <type_traits>
21#include <utility>
22
23#include "tensorflow/core/framework/register_types.h"
24#include "tensorflow/core/framework/types.h"
25#include "tensorflow/core/framework/variant.h"
26#include "tensorflow/core/kernels/initializable_lookup_table.h"
27#include "tensorflow/core/lib/gtl/inlined_vector.h"
28#include "tensorflow/core/lib/hash/hash.h"
29#include "tensorflow/core/platform/random.h"
30
31namespace tensorflow {
32namespace lookup {
33
34std::string UniqueNodeName(const std::string& base) {
35 static std::atomic<int64_t> counter(0);
36 return strings::StrCat(base, "/", counter.fetch_add(1), "/", random::New64());
37}
38
39// Lookup table that wraps an unordered_map, where the key and value data type
40// is specified. Each individual value must be a scalar. If vector values are
41// required, use MutableHashTableOfTensors.
42//
43// This table is mutable and thread safe - Insert can be called at any time.
44//
45// Sample use case:
46//
47// MutableHashTableOfScalars<int64, int64> table; // int64 -> int64.
48// // Populate the table, elements could be added in one or multiple calls.
49// table.Insert(key_tensor, value_tensor); // Populate the table.
50//
51// table.Find(in_t, &out_t, default_t)
52//
53template <class K, class V>
54class MutableHashTableOfScalars final : public LookupInterface {
55 public:
56 MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
57
58 size_t size() const override {
59 tf_shared_lock l(mu_);
60 return table_.size();
61 }
62
63 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
64 const Tensor& default_value) override {
65 const auto key_values = key.flat<K>();
66 auto value_values = value->flat<V>();
67 const auto default_flat = default_value.flat<V>();
68
69 int64_t total = value_values.size();
70 int64_t default_total = default_flat.size();
71 bool is_full_size_default = (total == default_total);
72
73 tf_shared_lock l(mu_);
74 for (int64_t i = 0; i < key_values.size(); ++i) {
75 // is_full_size_default is true:
76 // Each key has an independent default value, key_values(i)
77 // corresponding uses default_flat(i) as its default value.
78 //
79 // is_full_size_default is false:
80 // All keys will share the default_flat(0) as default value.
81 value_values(i) = gtl::FindWithDefault(
82 table_, SubtleMustCopyIfIntegral(key_values(i)),
83 is_full_size_default ? default_flat(i) : default_flat(0));
84 }
85
86 return OkStatus();
87 }
88
89 Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
90 const auto key_values = keys.flat<K>();
91 const auto value_values = values.flat<V>();
92
93 mutex_lock l(mu_);
94 if (clear) {
95 table_.clear();
96 }
97 for (int64_t i = 0; i < key_values.size(); ++i) {
98 gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
99 SubtleMustCopyIfIntegral(value_values(i)));
100 }
101 return OkStatus();
102 }
103
104 Status Insert(OpKernelContext* ctx, const Tensor& keys,
105 const Tensor& values) override {
106 return DoInsert(false, keys, values);
107 }
108
109 Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
110 const auto key_values = keys.flat<K>();
111
112 mutex_lock l(mu_);
113 for (int64_t i = 0; i < key_values.size(); ++i) {
114 table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
115 }
116 return OkStatus();
117 }
118
119 Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
120 const Tensor& values) override {
121 return DoInsert(true, keys, values);
122 }
123
124 Status ExportValues(OpKernelContext* ctx) override {
125 tf_shared_lock l(mu_);
126 int64_t size = table_.size();
127
128 Tensor* keys;
129 Tensor* values;
130 TF_RETURN_IF_ERROR(
131 ctx->allocate_output("keys", TensorShape({size}), &keys));
132 TF_RETURN_IF_ERROR(
133 ctx->allocate_output("values", TensorShape({size}), &values));
134 ExportKeysAndValues(keys, values);
135 return OkStatus();
136 }
137
138 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
139
140 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
141
142 TensorShape key_shape() const final { return TensorShape(); }
143
144 TensorShape value_shape() const override { return TensorShape(); }
145
146 int64_t MemoryUsed() const override {
147 int64_t ret = 0;
148 tf_shared_lock l(mu_);
149 for (unsigned i = 0; i < table_.bucket_count(); ++i) {
150 size_t bucket_size = table_.bucket_size(i);
151 if (bucket_size == 0) {
152 ret++;
153 } else {
154 ret += bucket_size;
155 }
156 }
157 return sizeof(MutableHashTableOfScalars) + ret;
158 }
159
160 Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
161 tf_shared_lock l(mu_);
162 int64_t size = table_.size();
163 Tensor keys(key_dtype(), TensorShape({size}));
164 Tensor values(value_dtype(), TensorShape({size}));
165 ExportKeysAndValues(&keys, &values);
166
167 // We set use_node_name_sharing with a unique node name so that the resource
168 // can outlive the MutableHashTableV2 kernel. This means that the lifetime
169 // of the resource will be tied to the lifetime of the resource manager it
170 // is created in.
171 // TODO(b/181695913): Provide a mechanism for deleting this resource
172 // earlier when appropriate.
173 Node* table = ops::SourceOp(
174 "MutableHashTableV2",
175 builder->opts()
176 .WithName(UniqueNodeName("MutableHashTableFromGraphDef"))
177 .WithAttr("use_node_name_sharing", true)
178 .WithAttr("key_dtype", key_dtype())
179 .WithAttr("value_dtype", value_dtype()));
180 Node* keys_node = ops::SourceOp(
181 "Const",
182 builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
183 Node* values_node =
184 ops::SourceOp("Const", builder->opts()
185 .WithAttr("dtype", value_dtype())
186 .WithAttr("value", values));
187 Node* import_table =
188 ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
189 builder->opts()
190 .WithAttr("Tin", key_dtype())
191 .WithAttr("Tout", value_dtype()));
192 *out = ops::UnaryOp("Identity", table,
193 builder->opts().WithControlInput(import_table));
194 return OkStatus();
195 }
196
197 private:
198 // Writes all keys and values into `keys` and `values`. `keys` and `values`
199 // must point to tensors of size `table_.size()`.
200 void ExportKeysAndValues(Tensor* keys, Tensor* values) const
201 TF_SHARED_LOCKS_REQUIRED(mu_) {
202 auto keys_data = keys->flat<K>();
203 auto values_data = values->flat<V>();
204 int64_t i = 0;
205 for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
206 keys_data(i) = it->first;
207 values_data(i) = it->second;
208 }
209 }
210
211 mutable mutex mu_;
212 std::unordered_map<K, V> table_ TF_GUARDED_BY(mu_);
213};
214
215// Lookup table that wraps an unordered_map. Behaves identical to
216// MutableHashTableOfScalars except that each value must be a vector.
217template <class K, class V>
218class MutableHashTableOfTensors final : public LookupInterface {
219 public:
220 MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
221 OP_REQUIRES_OK(ctx,
222 GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
223 OP_REQUIRES(
224 ctx, TensorShapeUtils::IsVector(value_shape_),
225 errors::InvalidArgument("Default value must be a vector, got shape ",
226 value_shape_.DebugString()));
227 }
228
229 size_t size() const override {
230 tf_shared_lock l(mu_);
231 return table_.size();
232 }
233
234 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
235 const Tensor& default_value) override {
236 const auto default_flat = default_value.flat_inner_dims<V, 2>();
237 const auto key_values = key.flat<K>();
238 auto value_values = value->flat_inner_dims<V, 2>();
239 int64_t value_dim = value_shape_.dim_size(0);
240
241 int64_t total = value_values.size();
242 int64_t default_total = default_flat.size();
243 bool is_full_size_default = (total == default_total);
244
245 tf_shared_lock l(mu_);
246 for (int64_t i = 0; i < key_values.size(); ++i) {
247 ValueArray* value_vec =
248 gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
249 if (value_vec != nullptr) {
250 for (int64_t j = 0; j < value_dim; j++) {
251 value_values(i, j) = value_vec->at(j);
252 }
253 } else {
254 // is_full_size_default is true:
255 // Each key has an independent default value, key_values(i)
256 // corresponding uses default_flat(i) as its default value.
257 //
258 // is_full_size_default is false:
259 // All keys will share the default_flat(0) as default value.
260 for (int64_t j = 0; j < value_dim; j++) {
261 value_values(i, j) =
262 is_full_size_default ? default_flat(i, j) : default_flat(0, j);
263 }
264 }
265 }
266
267 return OkStatus();
268 }
269
270 Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
271 const auto key_values = keys.flat<K>();
272 const auto value_values = values.flat_inner_dims<V, 2>();
273 int64_t value_dim = value_shape_.dim_size(0);
274
275 mutex_lock l(mu_);
276 if (clear) {
277 table_.clear();
278 }
279 for (int64_t i = 0; i < key_values.size(); ++i) {
280 ValueArray value_vec;
281 for (int64_t j = 0; j < value_dim; j++) {
282 V value = value_values(i, j);
283 value_vec.push_back(value);
284 }
285 gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
286 value_vec);
287 }
288 return OkStatus();
289 }
290
291 Status Insert(OpKernelContext* ctx, const Tensor& keys,
292 const Tensor& values) override {
293 return DoInsert(false, keys, values);
294 }
295
296 Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
297 const auto key_values = keys.flat<K>();
298
299 mutex_lock l(mu_);
300 for (int64_t i = 0; i < key_values.size(); ++i) {
301 table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
302 }
303 return OkStatus();
304 }
305
306 Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
307 const Tensor& values) override {
308 return DoInsert(true, keys, values);
309 }
310
311 Status ExportValues(OpKernelContext* ctx) override {
312 tf_shared_lock l(mu_);
313 int64_t size = table_.size();
314 int64_t value_dim = value_shape_.dim_size(0);
315
316 Tensor* keys;
317 Tensor* values;
318 TF_RETURN_IF_ERROR(
319 ctx->allocate_output("keys", TensorShape({size}), &keys));
320 TF_RETURN_IF_ERROR(ctx->allocate_output(
321 "values", TensorShape({size, value_dim}), &values));
322 ExportKeysAndValues(keys, values);
323 return OkStatus();
324 }
325
326 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
327
328 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
329
330 TensorShape key_shape() const final { return TensorShape(); }
331
332 TensorShape value_shape() const override { return value_shape_; }
333
334 int64_t MemoryUsed() const override {
335 int64_t ret = 0;
336 tf_shared_lock l(mu_);
337 for (unsigned i = 0; i < table_.bucket_count(); ++i) {
338 size_t bucket_size = table_.bucket_size(i);
339 if (bucket_size == 0) {
340 ret++;
341 } else {
342 ret += bucket_size;
343 }
344 }
345 return sizeof(MutableHashTableOfTensors) + ret;
346 }
347
348 Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
349 tf_shared_lock l(mu_);
350 int64_t size = table_.size();
351 Tensor keys(key_dtype(), TensorShape({size}));
352 Tensor values(value_dtype(), TensorShape({size, value_shape_.dim_size(0)}));
353 ExportKeysAndValues(&keys, &values);
354
355 // We set use_node_name_sharing with a unique node name so that the resource
356 // can outlive the MutableHashTableOfTensorsV2 kernel. This means that the
357 // lifetime of the resource will be tied to the lifetime of the resource
358 // manager it is created in.
359 // TODO(b/181695913): Provide a mechanism for deleting this resource
360 // earlier when appropriate.
361 Node* table =
362 ops::SourceOp("MutableHashTableOfTensorsV2",
363 builder->opts()
364 .WithName(UniqueNodeName("MutableHashTableOfTensors"))
365 .WithAttr("use_node_name_sharing", true)
366 .WithAttr("key_dtype", key_dtype())
367 .WithAttr("value_dtype", value_dtype())
368 .WithAttr("value_shape", value_shape_));
369 Node* keys_node = ops::SourceOp(
370 "Const",
371 builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
372 Node* values_node =
373 ops::SourceOp("Const", builder->opts()
374 .WithAttr("dtype", value_dtype())
375 .WithAttr("value", values));
376 Node* import_table =
377 ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
378 builder->opts()
379 .WithAttr("Tin", key_dtype())
380 .WithAttr("Tout", value_dtype()));
381 *out = ops::UnaryOp("Identity", table,
382 builder->opts().WithControlInput(import_table));
383 return OkStatus();
384 }
385
386 private:
387 // Writes all keys and values into `keys` and `values`. `keys` and `values`
388 // must point to tensors of size `table_.size()`.
389 void ExportKeysAndValues(Tensor* keys, Tensor* values) const
390 TF_SHARED_LOCKS_REQUIRED(mu_) {
391 int64_t value_dim = value_shape_.dim_size(0);
392 auto keys_data = keys->flat<K>();
393 auto values_data = values->matrix<V>();
394 int64_t i = 0;
395 for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
396 K key = it->first;
397 ValueArray value = it->second;
398 keys_data(i) = key;
399 for (int64_t j = 0; j < value_dim; j++) {
400 values_data(i, j) = value[j];
401 }
402 }
403 }
404
405 TensorShape value_shape_;
406 mutable mutex mu_;
407 typedef gtl::InlinedVector<V, 4> ValueArray;
408 std::unordered_map<K, ValueArray> table_ TF_GUARDED_BY(mu_);
409};
410
411namespace {
412
413template <typename T>
414inline uint64 HashScalar(const T& key) {
415 return static_cast<uint64>(key);
416}
417
418inline uint64 HashScalar(const tstring& key) { return Hash64(key); }
419
420// If the given shape is a scalar return {1} instead. Otherwise leave it alone.
421TensorShape MaybeVectorizeShape(const TensorShape& shape) {
422 if (shape.dims() == 0) {
423 return TensorShape({1});
424 }
425 return shape;
426}
427
428} // namespace
429
430// Modeled after densehashtable in https://github.com/sparsehash/sparsehash
431template <class K, class V>
432class MutableDenseHashTable final : public LookupInterface {
433 public:
434 MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) {
435 OP_REQUIRES_OK(
436 ctx, GetNodeAttr(kernel->def(), "max_load_factor", &max_load_factor_));
437 OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1,
438 errors::InvalidArgument(
439 "max_load_factor must be between 0 and 1, got: ",
440 max_load_factor_));
441
442 OP_REQUIRES_OK(ctx,
443 GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
444 OP_REQUIRES(ctx,
445 TensorShapeUtils::IsScalar(value_shape_) ||
446 TensorShapeUtils::IsVector(value_shape_),
447 errors::InvalidArgument(
448 "Empty value must be a scalar or a vector, got shape ",
449 value_shape_.DebugString()));
450
451 const Tensor* empty_key_input;
452 OP_REQUIRES_OK(ctx, ctx->input("empty_key", &empty_key_input));
453 key_shape_ = empty_key_input->shape();
454 OP_REQUIRES(ctx,
455 TensorShapeUtils::IsScalar(key_shape_) ||
456 TensorShapeUtils::IsVector(key_shape_),
457 errors::InvalidArgument(
458 "Empty key must be a scalar or a vector, got shape ",
459 key_shape_.DebugString()));
460 empty_key_ = *empty_key_input;
461 empty_key_hash_ = HashKey(
462 empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
463 0);
464
465 const Tensor* deleted_key_input;
466 OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
467 OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
468 errors::InvalidArgument(
469 "Empty and deleted keys must have same shape, got shapes: ",
470 key_shape_.DebugString(), " and ",
471 deleted_key_input->shape().DebugString()));
472 deleted_key_ = *deleted_key_input;
473 deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
474 {1, key_shape_.num_elements()}),
475 0);
476
477 if (empty_key_hash_ == deleted_key_hash_) {
478 const int64_t key_size = key_shape_.num_elements();
479 const auto empty_key_matrix =
480 empty_key_.template shaped<K, 2>({1, key_size});
481 const auto deleted_key_matrix =
482 deleted_key_.template shaped<K, 2>({1, key_size});
483 OP_REQUIRES(
484 ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
485 errors::InvalidArgument("Empty and deleted keys cannot be equal"));
486 }
487
488 int64_t initial_num_buckets;
489 OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
490 &initial_num_buckets));
491 OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets));
492 }
493
494 size_t size() const override TF_LOCKS_EXCLUDED(mu_) {
495 tf_shared_lock l(mu_);
496 return num_entries_;
497 }
498
499 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
500 const Tensor& default_value) override TF_LOCKS_EXCLUDED(mu_) {
501 const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
502 const int64_t key_size = key_shape_.num_elements();
503 const int64_t value_size = value_shape_.num_elements();
504 if (key.NumElements() != num_elements * key_size) {
505 TensorShape expected_shape({num_elements});
506 expected_shape.AppendShape(key_shape_);
507 return errors::InvalidArgument("Expected key shape ",
508 expected_shape.DebugString(), " got ",
509 key.shape().DebugString());
510 }
511 const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
512 auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
513 const auto default_flat = default_value.flat<V>();
514
515 tf_shared_lock l(mu_);
516 const auto key_buckets_matrix = key_buckets_.template matrix<K>();
517 const auto value_buckets_matrix = value_buckets_.template matrix<V>();
518 const auto empty_key_matrix =
519 empty_key_.template shaped<K, 2>({1, key_size});
520 const auto deleted_key_matrix =
521 deleted_key_.template shaped<K, 2>({1, key_size});
522 const int64_t bit_mask = num_buckets_ - 1;
523 // TODO(andreasst): parallelize using work_sharder
524 for (int64_t i = 0; i < num_elements; ++i) {
525 const uint64 key_hash = HashKey(key_matrix, i);
526 if (empty_key_hash_ == key_hash &&
527 IsEqualKey(empty_key_matrix, 0, key_matrix, i)) {
528 return errors::InvalidArgument(
529 "Using the empty_key as a table key is not allowed");
530 }
531 if (deleted_key_hash_ == key_hash &&
532 IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
533 return errors::InvalidArgument(
534 "Using the deleted_key as a table key is not allowed");
535 }
536 int64_t bucket_index = key_hash & bit_mask;
537 int64_t num_probes = 0;
538 while (true) {
539 if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
540 for (int64_t j = 0; j < value_size; ++j) {
541 // TODO(andreasst): check if we can get rid of SubtleMustCopy
542 // here and elsewhere in this file.
543 value_matrix(i, j) =
544 SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j));
545 }
546 break;
547 }
548 if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) {
549 for (int64_t j = 0; j < value_size; ++j) {
550 value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j));
551 }
552 break;
553 }
554 ++num_probes;
555 bucket_index =
556 (bucket_index + num_probes) & bit_mask; // quadratic probing
557 if (num_probes >= num_buckets_) {
558 return errors::Internal(
559 "Internal error in MutableDenseHashTable lookup");
560 }
561 }
562 }
563 return OkStatus();
564 }
565
566 Status Insert(OpKernelContext* ctx, const Tensor& key,
567 const Tensor& value) override TF_LOCKS_EXCLUDED(mu_) {
568 const int64_t batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
569 if (key.NumElements() != batch_size * key_shape_.num_elements()) {
570 TensorShape expected_shape({batch_size});
571 expected_shape.AppendShape(key_shape_);
572 return errors::InvalidArgument("Expected key shape ",
573 expected_shape.DebugString(), " got ",
574 key.shape().DebugString());
575 }
576 mutex_lock l(mu_);
577 // For simplicity we assume that all keys in the input result in inserts
578 // rather than updates. That means we may grow the table even though we
579 // don't need to. As long as the number of keys inserted in one call is
580 // small compared to the size of the map, the impact of this is minimal.
581 const int64_t pending_num_entries = num_entries_ + batch_size;
582 if (pending_num_entries > num_buckets_ * max_load_factor_) {
583 int64_t new_num_buckets = num_buckets_;
584 do {
585 new_num_buckets <<= 1;
586 } while (pending_num_entries > new_num_buckets * max_load_factor_);
587 TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets));
588 }
589 return DoInsert(ctx, key, value, false);
590 }
591
592 Status Remove(OpKernelContext* ctx, const Tensor& key) override
593 TF_LOCKS_EXCLUDED(mu_) {
594 if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
595 TensorShape expected_shape({key.dim_size(0)});
596 expected_shape.AppendShape(key_shape_);
597 return errors::InvalidArgument("Expected key shape ",
598 expected_shape.DebugString(), " got ",
599 key.shape().DebugString());
600 }
601 mutex_lock l(mu_);
602 return DoRemove(ctx, key);
603 }
604
605 Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
606 const Tensor& values) override TF_LOCKS_EXCLUDED(mu_) {
607 mutex_lock l(mu_);
608 num_buckets_ = keys.dim_size(0);
609 key_buckets_ = keys;
610 value_buckets_ = values;
611 // Count the number of keys that are not the empty_key or deleted_key.
612 // This requires iterating through the whole table but that is OK as we
613 // only execute it during checkpoint restore.
614 num_entries_ = 0;
615 const auto empty_key_tensor =
616 empty_key_.template shaped<K, 2>({1, key_shape_.num_elements()});
617 const auto deleted_key_tensor =
618 deleted_key_.template shaped<K, 2>({1, key_shape_.num_elements()});
619 const auto key_buckets_tensor = key_buckets_.template matrix<K>();
620 for (int64_t i = 0; i < num_buckets_; ++i) {
621 if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
622 !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
623 ++num_entries_;
624 }
625 }
626 return OkStatus();
627 }
628
629 Status ExportValues(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
630 tf_shared_lock l(mu_);
631 TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_));
632 TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_));
633 return OkStatus();
634 }
635
636 Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
637 const Tensor& values) override {
638 TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
639 TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));
640
641 // The storage format in key_buckets_ and value_buckets_ is always vectors,
642 // even if the inputs are scalars. This is what eventually gets exported
643 // and is expected by the import method as well.
644 TensorShape key_shape = MaybeVectorizeShape(key_shape_);
645 TensorShape value_shape = MaybeVectorizeShape(value_shape_);
646
647 // Compute the final expected shape of the value by starting with the shape
648 // of all keys, removing the dimensions particular to each key and then
649 // appending the shape of a single value.
650 TensorShape expected_value_shape = keys.shape();
651 expected_value_shape.RemoveLastDims(key_shape.dims());
652 expected_value_shape.AppendShape(value_shape);
653 if (values.shape() != expected_value_shape) {
654 return errors::InvalidArgument(
655 "Expected shape ", expected_value_shape.DebugString(),
656 " for value, got ", values.shape().DebugString());
657 }
658 return OkStatus();
659 }
660
661 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
662
663 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
664
665 TensorShape key_shape() const override { return key_shape_; }
666
667 TensorShape value_shape() const override { return value_shape_; }
668
669 int64_t MemoryUsed() const override TF_LOCKS_EXCLUDED(mu_) {
670 tf_shared_lock l(mu_);
671 return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
672 value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
673 }
674
675 private:
676 Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
677 bool ignore_empty_and_deleted_key)
678 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
679 const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
680 const int64_t value_size = value_shape_.num_elements();
681 const int64_t key_size = key_shape_.num_elements();
682 const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
683 auto value_matrix = value.shaped<V, 2>({num_elements, value_size});
684
685 auto key_buckets_matrix = key_buckets_.template matrix<K>();
686 auto value_buckets_matrix = value_buckets_.template matrix<V>();
687 const auto empty_key_tensor =
688 empty_key_.template shaped<K, 2>({1, key_size});
689 const auto deleted_key_tensor =
690 deleted_key_.template shaped<K, 2>({1, key_size});
691 const int64_t bit_mask = num_buckets_ - 1;
692 for (int64_t i = 0; i < num_elements; ++i) {
693 const uint64 key_hash = HashKey(key_matrix, i);
694 if (empty_key_hash_ == key_hash &&
695 IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
696 if (ignore_empty_and_deleted_key) {
697 continue;
698 }
699 return errors::InvalidArgument(
700 "Using the empty_key as a table key is not allowed");
701 }
702 if (deleted_key_hash_ == key_hash &&
703 IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
704 if (ignore_empty_and_deleted_key) {
705 continue;
706 }
707 return errors::InvalidArgument(
708 "Using the deleted_key as a table key is not allowed");
709 }
710 int64_t bucket_index = key_hash & bit_mask;
711 int64_t num_probes = 0;
712 while (true) {
713 if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
714 for (int64_t j = 0; j < value_size; ++j) {
715 value_buckets_matrix(bucket_index, j) =
716 SubtleMustCopyIfIntegral(value_matrix(i, j));
717 }
718 break;
719 }
720 if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
721 IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
722 0)) {
723 ++num_entries_;
724 for (int64_t j = 0; j < key_size; ++j) {
725 key_buckets_matrix(bucket_index, j) =
726 SubtleMustCopyIfIntegral(key_matrix(i, j));
727 }
728 for (int64_t j = 0; j < value_size; ++j) {
729 value_buckets_matrix(bucket_index, j) =
730 SubtleMustCopyIfIntegral(value_matrix(i, j));
731 }
732 break;
733 }
734 ++num_probes;
735 bucket_index =
736 (bucket_index + num_probes) & bit_mask; // quadratic probing
737 if (num_probes >= num_buckets_) {
738 return errors::Internal(
739 "Internal error in MutableDenseHashTable insert");
740 }
741 }
742 }
743 return OkStatus();
744 }
745
746 Status DoRemove(OpKernelContext* ctx, const Tensor& key)
747 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
748 const int64_t num_elements = key.dim_size(0);
749 const int64_t key_size = key_shape_.num_elements();
750 const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
751
752 auto key_buckets_matrix = key_buckets_.template matrix<K>();
753 const auto empty_key_tensor =
754 empty_key_.template shaped<K, 2>({1, key_size});
755 const auto deleted_key_tensor =
756 deleted_key_.template shaped<K, 2>({1, key_size});
757 const auto deleted_key_flat = deleted_key_.template flat<K>();
758 const int64_t bit_mask = num_buckets_ - 1;
759 for (int64_t i = 0; i < num_elements; ++i) {
760 const uint64 key_hash = HashKey(key_matrix, i);
761 if (empty_key_hash_ == key_hash &&
762 IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
763 return errors::InvalidArgument(
764 "Using the empty_key as a table key is not allowed");
765 }
766 if (deleted_key_hash_ == key_hash &&
767 IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
768 return errors::InvalidArgument(
769 "Using the deleted_key as a table key is not allowed");
770 }
771 int64_t bucket_index = key_hash & bit_mask;
772 int64_t num_probes = 0;
773 while (true) {
774 if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
775 --num_entries_;
776 for (int64_t j = 0; j < key_size; ++j) {
777 key_buckets_matrix(bucket_index, j) =
778 SubtleMustCopyIfIntegral(deleted_key_flat(j));
779 }
780 break;
781 }
782 if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
783 break;
784 }
785 ++num_probes;
786 bucket_index =
787 (bucket_index + num_probes) & bit_mask; // quadratic probing
788 if (num_probes >= num_buckets_) {
789 return errors::Internal(
790 "Internal error in MutableDenseHashTable remove");
791 }
792 }
793 }
794 return OkStatus();
795 }
796
797 Status AllocateBuckets(OpKernelContext* ctx, int64_t new_num_buckets)
798 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
799 if (new_num_buckets < 4 ||
800 ((new_num_buckets & (new_num_buckets - 1)) != 0)) {
801 return errors::InvalidArgument(
802 "Number of buckets must be at least 4 and a power of 2, got: ",
803 new_num_buckets);
804 }
805 num_buckets_ = new_num_buckets;
806 num_entries_ = 0;
807
808 const int64_t key_size = key_shape_.num_elements();
809 TF_RETURN_IF_ERROR(ctx->allocate_temp(
810 key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_));
811 auto key_buckets_matrix = key_buckets_.matrix<K>();
812 const auto empty_key_flat = empty_key_.template flat<K>();
813 for (int64_t i = 0; i < num_buckets_; ++i) {
814 for (int64_t j = 0; j < key_size; ++j) {
815 key_buckets_matrix(i, j) = empty_key_flat(j);
816 }
817 }
818
819 const int64_t value_size = value_shape_.num_elements();
820
821 TF_RETURN_IF_ERROR(ctx->allocate_temp(
822 value_dtype(), TensorShape({num_buckets_, value_size}),
823 &value_buckets_));
824 auto value_buckets_matrix = value_buckets_.matrix<V>();
825 for (int64_t i = 0; i < num_buckets_; ++i) {
826 for (int64_t j = 0; j < value_size; ++j) {
827 // Initialize values to the default value for the type to avoid
828 // exposing uninitialized memory in ExportValues().
829 value_buckets_matrix(i, j) = V();
830 }
831 }
832 return OkStatus();
833 }
834
835 Status Rebucket(OpKernelContext* ctx, int64_t num_new_buckets)
836 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
837 Tensor old_key_buckets = key_buckets_;
838 Tensor old_value_buckets = value_buckets_;
839 TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets));
840 return DoInsert(ctx, old_key_buckets, old_value_buckets, true);
841 }
842
843 uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64_t index) const {
844 if (key_shape_.num_elements() == 1) {
845 return HashScalar(key(index, 0));
846 }
847 uint64 result = 0;
848 for (int64_t i = 0; i < key_shape_.num_elements(); ++i) {
849 result = Hash64Combine(result, HashScalar(key(index, i)));
850 }
851 return result;
852 }
853
854 // Use a template to allow this function to be used both with Matrix and
855 // ConstMatrix types.
856 template <typename MT2>
857 bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64_t index1,
858 MT2 tensor2, int64_t index2) const {
859 for (int64_t i = 0; i < key_shape_.num_elements(); ++i) {
860 if (tensor1(index1, i) != tensor2(index2, i)) {
861 return false;
862 }
863 }
864 return true;
865 }
866
867 TensorShape key_shape_;
868 TensorShape value_shape_;
869 float max_load_factor_;
870 mutable mutex mu_;
871 int64_t num_entries_ TF_GUARDED_BY(mu_);
872 int64_t num_buckets_ TF_GUARDED_BY(mu_);
873 Tensor key_buckets_ TF_GUARDED_BY(mu_);
874 Tensor value_buckets_ TF_GUARDED_BY(mu_);
875 Tensor empty_key_;
876 uint64 empty_key_hash_;
877 Tensor deleted_key_;
878 uint64 deleted_key_hash_;
879};
880
881} // namespace lookup
882
883// Base class for kernels that take a LookupTable handle as the 0th input.
884class LookupTableOpKernel : public OpKernel {
885 public:
886 explicit LookupTableOpKernel(OpKernelConstruction* ctx)
887 : OpKernel(ctx),
888 expected_input_0_(ctx->input_type(0) == DT_RESOURCE ? DT_RESOURCE
889 : DT_STRING_REF) {}
890
891 protected:
892 Status GetTable(OpKernelContext* ctx, lookup::LookupInterface** table) {
893 if (expected_input_0_ == DT_RESOURCE) {
894 return GetResourceLookupTable("table_handle", ctx, table);
895 } else {
896 return GetReferenceLookupTable("table_handle", ctx, table);
897 }
898 }
899
900 // Input 0 could be a STRING_REF or a RESOURCE
901 const DataType expected_input_0_;
902};
903
904// Table lookup op. Perform the lookup operation on the given table.
905class LookupTableFindOp : public LookupTableOpKernel {
906 public:
907 using LookupTableOpKernel::LookupTableOpKernel;
908
909 void Compute(OpKernelContext* ctx) override {
910 lookup::LookupInterface* table;
911 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
912 core::ScopedUnref unref_me(table);
913
914 DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
915 table->value_dtype()};
916 DataTypeVector expected_outputs = {table->value_dtype()};
917 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
918
919 const Tensor& key = ctx->input(1);
920 const Tensor& default_value = ctx->input(2);
921 OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
922
923 TensorShape output_shape = key.shape();
924 output_shape.RemoveLastDims(table->key_shape().dims());
925 output_shape.AppendShape(table->value_shape());
926 Tensor* out;
927 OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
928
929 OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
930 }
931};
932
// One kernel class serves both op names; LookupTableOpKernel decides at
// construction whether input 0 is a resource or a string-ref handle.
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
                        LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
                        LookupTableFindOp);
937
938// Table insert op.
939class LookupTableInsertOp : public LookupTableOpKernel {
940 public:
941 using LookupTableOpKernel::LookupTableOpKernel;
942
943 void Compute(OpKernelContext* ctx) override {
944 lookup::LookupInterface* table;
945 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
946 core::ScopedUnref unref_me(table);
947
948 DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
949 table->value_dtype()};
950 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
951
952 const Tensor& keys = ctx->input(1);
953 const Tensor& values = ctx->input(2);
954 OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values));
955
956 int64_t memory_used_before = 0;
957 if (ctx->track_allocations()) {
958 memory_used_before = table->MemoryUsed();
959 }
960 OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values));
961 if (ctx->track_allocations()) {
962 ctx->record_persistent_memory_allocation(table->MemoryUsed() -
963 memory_used_before);
964 }
965 }
966};
967
// One kernel class serves both op names; LookupTableOpKernel decides at
// construction whether input 0 is a resource or a string-ref handle.
REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
                        LookupTableInsertOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
                        LookupTableInsertOp);
972
973// Table remove op.
974class LookupTableRemoveOp : public LookupTableOpKernel {
975 public:
976 using LookupTableOpKernel::LookupTableOpKernel;
977
978 void Compute(OpKernelContext* ctx) override {
979 lookup::LookupInterface* table;
980 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
981 core::ScopedUnref unref_me(table);
982
983 DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype()};
984 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
985
986 const Tensor& key = ctx->input(1);
987 OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
988
989 int64_t memory_used_before = 0;
990 if (ctx->track_allocations()) {
991 memory_used_before = table->MemoryUsed();
992 }
993 OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
994 if (ctx->track_allocations()) {
995 ctx->record_persistent_memory_allocation(table->MemoryUsed() -
996 memory_used_before);
997 }
998 }
999};
1000
// Only the V2 op name is registered for removal.
REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
                        LookupTableRemoveOp);
1003
1004// Op that returns the size of the given table.
1005class LookupTableSizeOp : public LookupTableOpKernel {
1006 public:
1007 using LookupTableOpKernel::LookupTableOpKernel;
1008
1009 void Compute(OpKernelContext* ctx) override {
1010 lookup::LookupInterface* table;
1011 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
1012 core::ScopedUnref unref_me(table);
1013
1014 Tensor* out;
1015 OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
1016 out->flat<int64_t>().setConstant(table->size());
1017 }
1018};
1019
// One kernel class serves both op names; LookupTableOpKernel decides at
// construction whether input 0 is a resource or a string-ref handle.
REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
                        LookupTableSizeOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
                        LookupTableSizeOp);
1024
1025// Op that outputs tensors of all keys and all values.
1026class LookupTableExportOp : public LookupTableOpKernel {
1027 public:
1028 using LookupTableOpKernel::LookupTableOpKernel;
1029
1030 void Compute(OpKernelContext* ctx) override {
1031 lookup::LookupInterface* table;
1032 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
1033 core::ScopedUnref unref_me(table);
1034
1035 OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
1036 }
1037};
1038
// One kernel class serves both op names; LookupTableOpKernel decides at
// construction whether input 0 is a resource or a string-ref handle.
REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
                        LookupTableExportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
                        LookupTableExportOp);
1043
1044// Clear the table and insert data.
1045class LookupTableImportOp : public LookupTableOpKernel {
1046 public:
1047 using LookupTableOpKernel::LookupTableOpKernel;
1048
1049 void Compute(OpKernelContext* ctx) override {
1050 lookup::LookupInterface* table;
1051 OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
1052 core::ScopedUnref unref_me(table);
1053
1054 DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
1055 table->value_dtype()};
1056 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
1057
1058 const Tensor& keys = ctx->input(1);
1059 const Tensor& values = ctx->input(2);
1060 OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
1061
1062 int memory_used_before = 0;
1063 if (ctx->track_allocations()) {
1064 memory_used_before = table->MemoryUsed();
1065 }
1066 OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
1067 if (ctx->track_allocations()) {
1068 ctx->record_persistent_memory_allocation(table->MemoryUsed() -
1069 memory_used_before);
1070 }
1071 }
1072};
1073
// One kernel class serves both op names; LookupTableOpKernel decides at
// construction whether input 0 is a resource or a string-ref handle.
REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
                        LookupTableImportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
                        LookupTableImportOp);
1078
// Register the HashTable op with the currently supported key and value types.
//
// REGISTER_KERNEL(K, V) registers three CPU kernels backed by
// lookup::HashTable<K, V>: "HashTable" and "HashTableV2" via LookupTableOp,
// and "AnonymousHashTable" via AnonymousLookupTableOp. Each is constrained on
// the "key_dtype"/"value_dtype" attrs. Keep comments outside the macro body:
// a `//` comment on a backslash-continued line also swallows the next line.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
  REGISTER_KERNEL_BUILDER( \
      Name("HashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
                    value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("HashTableV2") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
                    value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousHashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      AnonymousLookupTableOp<lookup::HashTable<key_dtype, value_dtype>, \
                             key_dtype, value_dtype>)

// Supported (key_dtype, value_dtype) combinations.
REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int32, tstring);
REGISTER_KERNEL(int64_t, double);
REGISTER_KERNEL(int64_t, float);
REGISTER_KERNEL(int64_t, int32);
REGISTER_KERNEL(int64_t, int64_t);
REGISTER_KERNEL(int64_t, tstring);
REGISTER_KERNEL(tstring, bool);
REGISTER_KERNEL(tstring, double);
REGISTER_KERNEL(tstring, float);
REGISTER_KERNEL(tstring, int32);
REGISTER_KERNEL(tstring, int64_t);
REGISTER_KERNEL(tstring, tstring);

// Undefine so the next registration group can reuse the macro name.
#undef REGISTER_KERNEL
1120
// Register the MutableHashTable op.
//
// REGISTER_KERNEL(K, V) registers three CPU kernels backed by
// lookup::MutableHashTableOfScalars<K, V>: "MutableHashTable" and
// "MutableHashTableV2" via LookupTableOp, and "AnonymousMutableHashTable" via
// AnonymousLookupTableOp. Keep comments outside the macro body: a `//`
// comment on a backslash-continued line also swallows the next line.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableHashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableHashTableV2") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousMutableHashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      AnonymousLookupTableOp< \
          lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
          key_dtype, value_dtype>)

// Supported (key_dtype, value_dtype) combinations.
REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64_t, double);
REGISTER_KERNEL(int64_t, float);
REGISTER_KERNEL(int64_t, int32);
REGISTER_KERNEL(int64_t, int64_t);
REGISTER_KERNEL(int64_t, tstring);
REGISTER_KERNEL(int64_t, Variant);
REGISTER_KERNEL(tstring, bool);
REGISTER_KERNEL(tstring, double);
REGISTER_KERNEL(tstring, float);
REGISTER_KERNEL(tstring, int32);
REGISTER_KERNEL(tstring, int64_t);

// Undefine so the next registration group can reuse the macro name.
#undef REGISTER_KERNEL
1162
// Register the MutableHashTableOfTensors op.
//
// REGISTER_KERNEL(K, V) registers three CPU kernels backed by
// lookup::MutableHashTableOfTensors<K, V>: "MutableHashTableOfTensors" and
// "MutableHashTableOfTensorsV2" via LookupTableOp, and
// "AnonymousMutableHashTableOfTensors" via AnonymousLookupTableOp. Keep
// comments outside the macro body: a `//` comment on a backslash-continued
// line also swallows the next line.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableHashTableOfTensors") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableHashTableOfTensorsV2") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousMutableHashTableOfTensors") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      AnonymousLookupTableOp< \
          lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
          key_dtype, value_dtype>)

// Supported (key_dtype, value_dtype) combinations.
REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64_t, double);
REGISTER_KERNEL(int64_t, float);
REGISTER_KERNEL(int64_t, int32);
REGISTER_KERNEL(int64_t, int64_t);
REGISTER_KERNEL(int64_t, tstring);
REGISTER_KERNEL(tstring, bool);
REGISTER_KERNEL(tstring, double);
REGISTER_KERNEL(tstring, float);
REGISTER_KERNEL(tstring, int32);
REGISTER_KERNEL(tstring, int64_t);

// Undefine so the next registration group can reuse the macro name.
#undef REGISTER_KERNEL
1203
// Register the MutableDenseHashTable op.
//
// REGISTER_KERNEL(K, V) registers three CPU kernels backed by
// lookup::MutableDenseHashTable<K, V>: "MutableDenseHashTable" and
// "MutableDenseHashTableV2" via LookupTableOp, and
// "AnonymousMutableDenseHashTable" via AnonymousLookupTableOp. Keep comments
// outside the macro body: a `//` comment on a backslash-continued line also
// swallows the next line.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableDenseHashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("MutableDenseHashTableV2") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>) \
  REGISTER_KERNEL_BUILDER( \
      Name("AnonymousMutableDenseHashTable") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<key_dtype>("key_dtype") \
          .TypeConstraint<value_dtype>("value_dtype"), \
      AnonymousLookupTableOp< \
          lookup::MutableDenseHashTable<key_dtype, value_dtype>, key_dtype, \
          value_dtype>)

// Supported (key_dtype, value_dtype) combinations.
REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64_t, bool);
REGISTER_KERNEL(int64_t, double);
REGISTER_KERNEL(int64_t, float);
REGISTER_KERNEL(int64_t, int32);
REGISTER_KERNEL(int64_t, int64_t);
REGISTER_KERNEL(int64_t, Variant);
REGISTER_KERNEL(tstring, bool);
REGISTER_KERNEL(tstring, double);
REGISTER_KERNEL(tstring, float);
REGISTER_KERNEL(tstring, int32);
REGISTER_KERNEL(tstring, int64_t);
REGISTER_KERNEL(tstring, ResourceHandle);

// Undefine so the macro name does not leak past this group.
#undef REGISTER_KERNEL
1246
1247} // namespace tensorflow
1248