1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
17#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
18
19#include "absl/container/flat_hash_map.h"
20#include "tensorflow/core/framework/bounds_check.h"
21#include "tensorflow/core/framework/lookup_interface.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/resource_mgr.h"
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/tensor_shape.h"
26#include "tensorflow/core/graph/graph_def_builder.h"
27#include "tensorflow/core/kernels/lookup_util.h"
28#include "tensorflow/core/lib/core/errors.h"
29#include "tensorflow/core/lib/core/status.h"
30#include "tensorflow/core/lib/gtl/map_util.h"
31#include "tensorflow/core/platform/errors.h"
32#include "tensorflow/core/platform/macros.h"
33#include "tensorflow/core/platform/thread_annotations.h"
34
35namespace tensorflow {
36
37// Lookup table op that supports different table implementations specified by
38// the 'Container' template. Container must be derived from LookupInterface. The
39// key and value are of the templated type "key_dtype" and "value_dtype"
40// respectively.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_set_(false) {
    // Pre-allocate the output tensor that will refer to the table: a scalar
    // DT_RESOURCE handle for resource-based kernels, or a 2-element DT_STRING
    // ref tensor holding {container, name} for legacy reference-typed kernels.
    if (ctx->output_type(0) == DT_RESOURCE) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_RESOURCE,
                                        tensorflow::TensorShape({}), &table_));
    } else {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_STRING,
                                        tensorflow::TensorShape({2}), &table_));
    }
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // First invocation only: resolve the container/shared-name this table
    // will live under in the resource manager.
    if (!table_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(),  def(),
                                      use_node_name_sharing_));
    }

    // Invoked by LookupOrCreate only when no table exists yet under
    // (container, name). Construction errors surface via ctx->status();
    // in that case the half-built container's reference is released.
    auto creator =
        [ctx, this](lookup::LookupInterface** ret)
            TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              lookup::LookupInterface* container = new Container(ctx, this);
              if (!ctx->status().ok()) {
                container->Unref();
                return ctx->status();
              }
              if (ctx->track_allocations()) {
                ctx->record_persistent_memory_allocation(
                    container->MemoryUsed() + table_.AllocatedBytes());
              }
              *ret = container;
              return OkStatus();
            };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    // LookupOrCreate handed us a new reference; drop it on scope exit (the
    // resource manager retains its own reference to keep the table alive).
    core::ScopedUnref unref_me(table);

    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Fill the handle once; later Compute calls reuse the same tensor.
      if (!table_set_) {
        auto h = table_.template scalar<ResourceHandle>();
        h() = MakeResourceHandle<lookup::LookupInterface>(
            ctx, cinfo_.container(), cinfo_.name());
      }
      ctx->set_output(0, table_);
    } else {
      // Legacy path: publish {container, name} once, then expose the tensor
      // by reference, guarded by mu_.
      if (!table_set_) {
        auto h = table_.template flat<tstring>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, &table_);
    }
    table_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_set_ && cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<lookup::LookupInterface>(cinfo_.container(),
                                                          cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

 private:
  mutex mu_;
  // Output tensor carrying the handle (or container/name pair) for the table.
  Tensor table_ TF_GUARDED_BY(mu_);
  // True once Compute has run and table_ has been populated.
  bool table_set_ TF_GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};
135
// An anonymous version of LookupTableOp, which creates a new table resource
// every time `Compute` is called. The resource can only be accessed by the
// returned resource handle (e.g. it cannot be looked up by name in a resource
// manager). The resource will be automatically deleted when all resource
// handles pointing to it are gone.
141template <class Container, class key_dtype, class value_dtype>
142class AnonymousLookupTableOp : public OpKernel {
143 public:
144 explicit AnonymousLookupTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
145
146 void Compute(OpKernelContext* ctx) override {
147 lookup::LookupInterface* table = new Container(ctx, this);
148 if (!ctx->status().ok()) {
149 table->Unref();
150 return;
151 }
152 Tensor table_tensor;
153 OP_REQUIRES_OK(
154 ctx, ctx->allocate_temp(tensorflow::DT_RESOURCE,
155 tensorflow::TensorShape({}), &table_tensor));
156 if (ctx->track_allocations()) {
157 ctx->record_persistent_memory_allocation(table->MemoryUsed() +
158 table_tensor.AllocatedBytes());
159 }
160 table_tensor.scalar<ResourceHandle>()() =
161 ResourceHandle::MakeRefCountingHandle<lookup::LookupInterface>(
162 table, ctx->device()->name());
163 ctx->set_output(0, table_tensor);
164 }
165
166 private:
167 TF_DISALLOW_COPY_AND_ASSIGN(AnonymousLookupTableOp);
168};
169
170namespace lookup {
171
172// Ensure that the compiler cannot elide a copy into a local, for
173// bounds checking on source tensors that might be updated asynchronously for
174// integral types. However non-integer variables are not allowed and therefore
175// the local copy is unnecessary.
176template <typename T>
177T SubtleMustCopyIfIntegral(const T& value) {
178 return internal::SubtleMustCopy(value);
179}
180
181inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
182 return value;
183}
184
185inline const float SubtleMustCopyIfIntegral(const float value) { return value; }
186
187inline const double SubtleMustCopyIfIntegral(const double value) {
188 return value;
189}
190
191inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
192 return value;
193}
194
195inline const ResourceHandle& SubtleMustCopyIfIntegral(
196 const ResourceHandle& value) {
197 return value;
198}
199
200// Returns a unique node name starting with "base".
201std::string UniqueNodeName(const std::string& base);
202
// Lookup table that wraps a flat_hash_map, where the key and value data types
// are specified.
//
// This table is recommended for any variation of key values.
//
// For lookup, the table must first be initialized (allocated and populated).
// Once the table is marked as initialized it becomes read-only.
//
// Sample use case:
//
// HashTable<int64, int64> table;  // int64 -> int64.
// table.Initialize(...);
// table.Find(in_t, &out_t, default_t)
//
217template <class K, class V>
218class HashTable : public InitializableLookupTable {
219 public:
220 HashTable(OpKernelContext* ctx, OpKernel* kernel) {}
221
222 Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
223 // We set use_node_name_sharing with a unique node name so that the resource
224 // can outlive the HashTableV2 kernel. This means that the lifetime of the
225 // HashTable resource will be tied to the lifetime of the resource manager
226 // it is created in.
227 // TODO(b/181695913): Provide a mechanism for deleting this resource
228 // earlier when appropriate.
229 Node* hash_table_node = ops::SourceOp(
230 "HashTableV2", builder->opts()
231 .WithName(UniqueNodeName("HashTableFromGraphDef"))
232 .WithAttr("key_dtype", key_dtype())
233 .WithAttr("value_dtype", value_dtype())
234 .WithAttr("use_node_name_sharing", true));
235 if (table_.empty()) {
236 *out = hash_table_node;
237 return OkStatus();
238 }
239
240 if (initializer_serializer_ == nullptr) {
241 std::string message =
242 "Failed to serialize lookup table: no initialization function was "
243 "specified. Falling back to serializing a handle to the table.";
244 LOG(WARNING) << message;
245 return errors::Unimplemented(message);
246 }
247 Node* initializer;
248 TF_RETURN_IF_ERROR(initializer_serializer_->AsGraphDef(
249 builder, hash_table_node, &initializer));
250 *out = ops::UnaryOp("Identity", hash_table_node,
251 builder->opts().WithControlInput(initializer));
252 return OkStatus();
253 }
254
255 size_t size() const override {
256 if (!is_initialized())
257 return 0;
258 else
259 return table_.size();
260 }
261
262 Status ExportValues(OpKernelContext* context) override {
263 if (!is_initialized()) {
264 return errors::Aborted("HashTable is not initialized.");
265 }
266
267 const int64_t size = table_.size();
268
269 Tensor* keys;
270 Tensor* values;
271 TF_RETURN_IF_ERROR(
272 context->allocate_output("keys", TensorShape({size}), &keys));
273 TF_RETURN_IF_ERROR(
274 context->allocate_output("values", TensorShape({size}), &values));
275
276 auto keys_data = keys->flat<K>();
277 auto values_data = values->flat<V>();
278 int64_t i = 0;
279 for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
280 keys_data(i) = it->first;
281 values_data(i) = it->second;
282 }
283 return OkStatus();
284 }
285
286 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
287
288 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
289
290 protected:
291 Status DoPrepare(size_t size) override {
292 if (is_initialized()) {
293 return errors::Aborted("HashTable already initialized.");
294 }
295 if (size > 0) {
296 table_.reserve(size);
297 }
298 return OkStatus();
299 };
300
301 Status DoLazyPrepare(std::function<int64(void)> size_fn) override {
302 return DoPrepare(size_fn());
303 }
304
305 Status DoInsert(const Tensor& keys, const Tensor& values) override {
306 const auto key_values = keys.flat<K>();
307 const auto value_values = values.flat<V>();
308 for (int64_t i = 0; i < key_values.size(); ++i) {
309 auto&& key = SubtleMustCopyIfIntegral(key_values(i));
310 auto&& value = SubtleMustCopyIfIntegral(value_values(i));
311 auto result = table_.try_emplace(key, value);
312 if (!result.second && result.first->second != value) {
313 return errors::FailedPrecondition(
314 "HashTable has different value for same key. Key ", key, " has ",
315 result.first->second, " and trying to add value ", value);
316 }
317 }
318 return OkStatus();
319 }
320
321 Status DoFind(const Tensor& key, Tensor* value,
322 const Tensor& default_value) override {
323 const V default_val = default_value.flat<V>()(0);
324 const auto key_values = key.flat<K>();
325 auto value_values = value->flat<V>();
326
327 for (int64_t i = 0; i < key_values.size(); ++i) {
328 value_values(i) = gtl::FindWithDefault(
329 table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
330 }
331 return OkStatus();
332 }
333
334 int64_t MemoryUsed() const override {
335 if (!is_initialized()) {
336 return 0;
337 }
338 const int64_t num_elements = table_.size();
339 return num_elements * (sizeof(K) + sizeof(V));
340 }
341
342 private:
343 absl::flat_hash_map<K, V> table_;
344};
345
346} // namespace lookup
347
348} // namespace tensorflow
349
350#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
351