1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ |
18 | |
19 | #include "absl/container/flat_hash_map.h" |
20 | #include "tensorflow/core/framework/bounds_check.h" |
21 | #include "tensorflow/core/framework/lookup_interface.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/resource_mgr.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/tensor_shape.h" |
26 | #include "tensorflow/core/graph/graph_def_builder.h" |
27 | #include "tensorflow/core/kernels/lookup_util.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/lib/core/status.h" |
30 | #include "tensorflow/core/lib/gtl/map_util.h" |
31 | #include "tensorflow/core/platform/errors.h" |
32 | #include "tensorflow/core/platform/macros.h" |
33 | #include "tensorflow/core/platform/thread_annotations.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | // Lookup table op that supports different table implementations specified by |
38 | // the 'Container' template. Container must be derived from LookupInterface. The |
39 | // key and value are of the templated type "key_dtype" and "value_dtype" |
40 | // respectively. |
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  //
  // Pre-allocates the tensor that will carry the table handle on output 0:
  // a scalar DT_RESOURCE handle when the op has a resource output, otherwise
  // a 2-element DT_STRING tensor holding {container, name} for the legacy
  // ref-output form.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_set_(false) {
    if (ctx->output_type(0) == DT_RESOURCE) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_RESOURCE,
                                        tensorflow::TensorShape({}), &table_));
    } else {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_STRING,
                                        tensorflow::TensorShape({2}), &table_));
    }
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing" , &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  //
  // Looks up (or lazily creates) the shared table in the resource manager and
  // publishes its handle as output 0. All of Compute is serialized by mu_, so
  // the handle tensor is filled in exactly once across concurrent calls.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // cinfo_ is resolved only on the first successful call; it determines the
    // container/shared_name under which the table is registered.
    if (!table_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(),
                                      def(), use_node_name_sharing_));
    }

    // Factory invoked by LookupOrCreate only when the resource does not exist
    // yet. If the Container constructor reported an error via ctx, the new
    // object's reference is dropped and the error status is propagated.
    auto creator =
        [ctx, this](lookup::LookupInterface** ret)
            TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              lookup::LookupInterface* container = new Container(ctx, this);
              if (!ctx->status().ok()) {
                container->Unref();
                return ctx->status();
              }
              if (ctx->track_allocations()) {
                // Account the table plus the handle tensor as persistent
                // memory held by this kernel.
                ctx->record_persistent_memory_allocation(
                    container->MemoryUsed() + table_.AllocatedBytes());
              }
              *ret = container;
              return OkStatus();
            };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table,
                           creator));
    // Balances the reference LookupOrCreate handed us when this scope exits.
    core::ScopedUnref unref_me(table);

    // Reject tables whose key/value dtypes do not match this kernel's
    // template parameters.
    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Resource output: fill the scalar handle once, then reuse it.
      if (!table_set_) {
        auto h = table_.template scalar<ResourceHandle>();
        h() = MakeResourceHandle<lookup::LookupInterface>(
            ctx, cinfo_.container(), cinfo_.name());
      }
      ctx->set_output(0, table_);
    } else {
      // Legacy ref output: publish {container, name} strings, guarded by mu_.
      if (!table_set_) {
        auto h = table_.template flat<tstring>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, &table_);
    }
    table_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_set_ && cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<lookup::LookupInterface>(cinfo_.container(),
                                                          cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

 private:
  mutex mu_;
  // Handle tensor for output 0 (resource scalar or {container,name} strings).
  Tensor table_ TF_GUARDED_BY(mu_);
  // True once table_ has been populated by a successful Compute call.
  bool table_set_ TF_GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};
135 | |
// An anonymous version of LookupTableOp, which creates a new table resource
// every time `Compute` is called. The resource can only be accessed through
// the returned resource handle (i.e. it cannot be looked up by name in a
// resource manager). The resource is automatically deleted once all resource
// handles pointing to it are gone.
141 | template <class Container, class key_dtype, class value_dtype> |
142 | class AnonymousLookupTableOp : public OpKernel { |
143 | public: |
144 | explicit AnonymousLookupTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
145 | |
146 | void Compute(OpKernelContext* ctx) override { |
147 | lookup::LookupInterface* table = new Container(ctx, this); |
148 | if (!ctx->status().ok()) { |
149 | table->Unref(); |
150 | return; |
151 | } |
152 | Tensor table_tensor; |
153 | OP_REQUIRES_OK( |
154 | ctx, ctx->allocate_temp(tensorflow::DT_RESOURCE, |
155 | tensorflow::TensorShape({}), &table_tensor)); |
156 | if (ctx->track_allocations()) { |
157 | ctx->record_persistent_memory_allocation(table->MemoryUsed() + |
158 | table_tensor.AllocatedBytes()); |
159 | } |
160 | table_tensor.scalar<ResourceHandle>()() = |
161 | ResourceHandle::MakeRefCountingHandle<lookup::LookupInterface>( |
162 | table, ctx->device()->name()); |
163 | ctx->set_output(0, table_tensor); |
164 | } |
165 | |
166 | private: |
167 | TF_DISALLOW_COPY_AND_ASSIGN(AnonymousLookupTableOp); |
168 | }; |
169 | |
170 | namespace lookup { |
171 | |
// Ensure that the compiler cannot elide a copy into a local, for bounds
// checking on source tensors that might be updated asynchronously, when the
// key/value type is integral. For non-integral types no such defensive copy
// is needed, so the overloads below simply pass the value through.
176 | template <typename T> |
177 | T SubtleMustCopyIfIntegral(const T& value) { |
178 | return internal::SubtleMustCopy(value); |
179 | } |
180 | |
// tstring is non-integral, so no defensive copy is required; return by
// const reference to avoid copying the string.
inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
  return value;
}
184 | |
// float is non-integral, so no defensive copy is required; pass through.
// Dropped the meaningless top-level `const` on the by-value return type
// (clang-tidy readability-const-return-type); callers are unaffected.
inline float SubtleMustCopyIfIntegral(const float value) { return value; }
186 | |
// double is non-integral, so no defensive copy is required; pass through.
// Dropped the meaningless top-level `const` on the by-value return type
// (clang-tidy readability-const-return-type); callers are unaffected.
inline double SubtleMustCopyIfIntegral(const double value) {
  return value;
}
190 | |
// Variant is non-integral (and potentially expensive to copy); return by
// const reference without a defensive copy.
inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
  return value;
}
194 | |
// ResourceHandle is non-integral; return by const reference without a
// defensive copy.
inline const ResourceHandle& SubtleMustCopyIfIntegral(
    const ResourceHandle& value) {
  return value;
}
199 | |
200 | // Returns a unique node name starting with "base". |
201 | std::string UniqueNodeName(const std::string& base); |
202 | |
// Lookup table that wraps a flat_hash_map, where the key and value data types
// are specified.
205 | // |
206 | // This table is recommended for any variations to key values. |
207 | // |
208 | // For look up, the table is required to be initialized (allocated |
209 | // and populated). Once the table is marked as initialized it becomes read-only. |
210 | // |
211 | // Sample use case: |
212 | // |
213 | // HashTable<int64, int64> table; // int64 -> int64. |
214 | // table.Initialize(...); |
215 | // table.Find(in_t, &out_t, default_t) |
216 | // |
217 | template <class K, class V> |
218 | class HashTable : public InitializableLookupTable { |
219 | public: |
220 | HashTable(OpKernelContext* ctx, OpKernel* kernel) {} |
221 | |
222 | Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { |
223 | // We set use_node_name_sharing with a unique node name so that the resource |
224 | // can outlive the HashTableV2 kernel. This means that the lifetime of the |
225 | // HashTable resource will be tied to the lifetime of the resource manager |
226 | // it is created in. |
227 | // TODO(b/181695913): Provide a mechanism for deleting this resource |
228 | // earlier when appropriate. |
229 | Node* hash_table_node = ops::SourceOp( |
230 | "HashTableV2" , builder->opts() |
231 | .WithName(UniqueNodeName("HashTableFromGraphDef" )) |
232 | .WithAttr("key_dtype" , key_dtype()) |
233 | .WithAttr("value_dtype" , value_dtype()) |
234 | .WithAttr("use_node_name_sharing" , true)); |
235 | if (table_.empty()) { |
236 | *out = hash_table_node; |
237 | return OkStatus(); |
238 | } |
239 | |
240 | if (initializer_serializer_ == nullptr) { |
241 | std::string message = |
242 | "Failed to serialize lookup table: no initialization function was " |
243 | "specified. Falling back to serializing a handle to the table." ; |
244 | LOG(WARNING) << message; |
245 | return errors::Unimplemented(message); |
246 | } |
247 | Node* initializer; |
248 | TF_RETURN_IF_ERROR(initializer_serializer_->AsGraphDef( |
249 | builder, hash_table_node, &initializer)); |
250 | *out = ops::UnaryOp("Identity" , hash_table_node, |
251 | builder->opts().WithControlInput(initializer)); |
252 | return OkStatus(); |
253 | } |
254 | |
255 | size_t size() const override { |
256 | if (!is_initialized()) |
257 | return 0; |
258 | else |
259 | return table_.size(); |
260 | } |
261 | |
262 | Status ExportValues(OpKernelContext* context) override { |
263 | if (!is_initialized()) { |
264 | return errors::Aborted("HashTable is not initialized." ); |
265 | } |
266 | |
267 | const int64_t size = table_.size(); |
268 | |
269 | Tensor* keys; |
270 | Tensor* values; |
271 | TF_RETURN_IF_ERROR( |
272 | context->allocate_output("keys" , TensorShape({size}), &keys)); |
273 | TF_RETURN_IF_ERROR( |
274 | context->allocate_output("values" , TensorShape({size}), &values)); |
275 | |
276 | auto keys_data = keys->flat<K>(); |
277 | auto values_data = values->flat<V>(); |
278 | int64_t i = 0; |
279 | for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { |
280 | keys_data(i) = it->first; |
281 | values_data(i) = it->second; |
282 | } |
283 | return OkStatus(); |
284 | } |
285 | |
286 | DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } |
287 | |
288 | DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } |
289 | |
290 | protected: |
291 | Status DoPrepare(size_t size) override { |
292 | if (is_initialized()) { |
293 | return errors::Aborted("HashTable already initialized." ); |
294 | } |
295 | if (size > 0) { |
296 | table_.reserve(size); |
297 | } |
298 | return OkStatus(); |
299 | }; |
300 | |
301 | Status DoLazyPrepare(std::function<int64(void)> size_fn) override { |
302 | return DoPrepare(size_fn()); |
303 | } |
304 | |
305 | Status DoInsert(const Tensor& keys, const Tensor& values) override { |
306 | const auto key_values = keys.flat<K>(); |
307 | const auto value_values = values.flat<V>(); |
308 | for (int64_t i = 0; i < key_values.size(); ++i) { |
309 | auto&& key = SubtleMustCopyIfIntegral(key_values(i)); |
310 | auto&& value = SubtleMustCopyIfIntegral(value_values(i)); |
311 | auto result = table_.try_emplace(key, value); |
312 | if (!result.second && result.first->second != value) { |
313 | return errors::FailedPrecondition( |
314 | "HashTable has different value for same key. Key " , key, " has " , |
315 | result.first->second, " and trying to add value " , value); |
316 | } |
317 | } |
318 | return OkStatus(); |
319 | } |
320 | |
321 | Status DoFind(const Tensor& key, Tensor* value, |
322 | const Tensor& default_value) override { |
323 | const V default_val = default_value.flat<V>()(0); |
324 | const auto key_values = key.flat<K>(); |
325 | auto value_values = value->flat<V>(); |
326 | |
327 | for (int64_t i = 0; i < key_values.size(); ++i) { |
328 | value_values(i) = gtl::FindWithDefault( |
329 | table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); |
330 | } |
331 | return OkStatus(); |
332 | } |
333 | |
334 | int64_t MemoryUsed() const override { |
335 | if (!is_initialized()) { |
336 | return 0; |
337 | } |
338 | const int64_t num_elements = table_.size(); |
339 | return num_elements * (sizeof(K) + sizeof(V)); |
340 | } |
341 | |
342 | private: |
343 | absl::flat_hash_map<K, V> table_; |
344 | }; |
345 | |
346 | } // namespace lookup |
347 | |
348 | } // namespace tensorflow |
349 | |
350 | #endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ |
351 | |