1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ |
18 | |
19 | #include "tensorflow/core/framework/lookup_interface.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/kernels/initializable_lookup_table.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace data { |
25 | class DatasetBase; /* forward declaration only; no declaration in this header refers to it */ |
26 | } // namespace data |
27 | } // namespace tensorflow |
28 | |
29 | namespace tensorflow { |
30 | namespace lookup { |
31 | |
32 | // Gets the LookupInterface stored in ctx->resource_manager() under the key passed by |
33 | // the attribute named `input_name`, handing it back through `table`; null if the table |
34 | // doesn't exist (NOTE(review): the return type is Status, so confirm how absence is reported). |
35 | // Use GetResourceLookupTable() or GetReferenceLookupTable() if the input dtype is known. |
36 | Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, |
37 | LookupInterface** table); |
38 | Status GetResourceLookupTable(StringPiece input_name, OpKernelContext* ctx, |
39 | LookupInterface** table); |
40 | Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext* ctx, |
41 | LookupInterface** table); |
42 | |
43 | // Gets the InitializableLookupTable stored in ctx->resource_manager() under the key |
44 | // passed by the attribute named `input_name`, handing it back through `table`; null if the |
45 | // table doesn't exist (NOTE(review): return type is Status, so confirm how absence is reported). |
46 | Status GetInitializableLookupTable(StringPiece input_name, OpKernelContext* ctx, |
47 | InitializableLookupTable** table); |
48 | |
49 | // Verifies that the given `key_dtype` and `value_dtype` match the key and value data |
50 | // types of `table`. NOTE(review): `table_name` presumably labels the table in errors; confirm. |
51 | Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, |
52 | DataType value_dtype, const string& table_name); |
53 | |
54 | // Initializes `table` from the text file `filename`. NOTE(review): the other parameters are undocumented here; see the definition. |
55 | Status InitializeTableFromTextFile(const string& filename, int64_t vocab_size, |
56 | char delimiter, int32_t key_index, |
57 | int32_t value_index, int64_t offset, |
58 | Env* env, InitializableLookupTable* table); |
59 | |
60 | // Initializes `table` from `filename`. `serializer` may specify how to represent the |
61 | // initializer as a graphdef, so that the table can be serialized as metadata. |
62 | Status InitializeTableFromTextFile( |
63 | const string& filename, int64_t vocab_size, char delimiter, |
64 | int32_t key_index, int32_t value_index, int64_t offset, Env* env, |
65 | std::unique_ptr<InitializableLookupTable::InitializerSerializer> serializer, |
66 | InitializableLookupTable* table); |
67 | |
68 | } // namespace lookup |
69 | } // namespace tensorflow |
70 | |
71 | #endif // TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ |
72 | |