1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ |
18 | |
19 | #include "tensorflow/core/framework/resource_mgr.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | class OpKernelContext; |
26 | |
27 | namespace lookup { |
28 | |
29 | // Forward declaration so we can define GetInitializableLookupTable() in |
30 | // LookupInterface. |
31 | class InitializableLookupTable; |
32 | |
33 | // Lookup interface for batch lookups used by table lookup ops. |
34 | class LookupInterface : public ResourceBase { |
35 | public: |
36 | // Performs batch lookups, for every element in the key tensor, Find returns |
37 | // the corresponding value into the values tensor. |
38 | // If an element is not present in the table, the given default value is used. |
39 | |
40 | // For tables that require initialization, Find is available once the table |
41 | // is marked as initialized. |
42 | |
43 | // Returns the following statuses: |
44 | // - OK: when the find finishes successfully. |
45 | // - FailedPrecondition: if the table is not initialized. |
46 | // - InvalidArgument: if any of the preconditions on the lookup key or value |
47 | // fails. |
48 | // - In addition, other implementations may provide another non-OK status |
49 | // specific to their failure modes. |
50 | virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, |
51 | const Tensor& default_value) = 0; |
52 | |
53 | // Inserts elements into the table. Each element of the key tensor is |
54 | // associated with the corresponding element in the value tensor. |
55 | // This method is only implemented in mutable tables that can be updated over |
56 | // the execution of the graph. It returns Status::NotImplemented for read-only |
57 | // tables that are initialized once before they can be looked up. |
58 | |
59 | // Returns the following statuses: |
60 | // - OK: when the insert finishes successfully. |
61 | // - InvalidArgument: if any of the preconditions on the lookup key or value |
62 | // fails. |
63 | // - Unimplemented: if the table does not support insertions. |
64 | virtual Status Insert(OpKernelContext* ctx, const Tensor& keys, |
65 | const Tensor& values) = 0; |
66 | |
67 | // Removes elements from the table. |
68 | // This method is only implemented in mutable tables that can be updated over |
69 | // the execution of the graph. It returns Status::NotImplemented for read-only |
70 | // tables that are initialized once before they can be looked up. |
71 | |
72 | // Returns the following statuses: |
73 | // - OK: when the remove finishes successfully. |
74 | // - InvalidArgument: if any of the preconditions on the lookup key fails. |
75 | // - Unimplemented: if the table does not support removals. |
76 | virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; |
77 | |
78 | // Returns the number of elements in the table. |
79 | virtual size_t size() const = 0; |
80 | |
81 | // Exports the values of the table to two tensors named keys and values. |
82 | // Note that the shape of the tensors is completely up to the implementation |
83 | // of the table and can be different than the tensors used for the Insert |
84 | // function above. |
85 | virtual Status ExportValues(OpKernelContext* ctx) = 0; |
86 | |
87 | // Imports previously exported keys and values. |
88 | // As mentioned above, the shape of the keys and values tensors are determined |
89 | // by the ExportValues function above and can be different than for the |
90 | // Insert function. |
91 | virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys, |
92 | const Tensor& values) = 0; |
93 | |
94 | // Returns the data type of the key. |
95 | virtual DataType key_dtype() const = 0; |
96 | |
97 | // Returns the data type of the value. |
98 | virtual DataType value_dtype() const = 0; |
99 | |
100 | // Returns the shape of a key in the table. |
101 | virtual TensorShape key_shape() const = 0; |
102 | |
103 | // Returns the shape of a value in the table. |
104 | virtual TensorShape value_shape() const = 0; |
105 | |
106 | // Check format of the key and value tensors for the Insert function. |
107 | // Returns OK if all the following requirements are satisfied, otherwise it |
108 | // returns InvalidArgument: |
109 | // - DataType of the tensor keys equals to the table key_dtype |
110 | // - DataType of the tensor values equals to the table value_dtype |
111 | // - the values tensor has the required shape given keys and the tables's |
112 | // value shape. |
113 | virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, |
114 | const Tensor& values); |
115 | |
116 | // Similar to the function above but instead checks eligibility for the Import |
117 | // function. |
118 | virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, |
119 | const Tensor& values); |
120 | |
121 | // Check format of the key tensor for the Remove function. |
122 | // Returns OK if all the following requirements are satisfied, otherwise it |
123 | // returns InvalidArgument: |
124 | // - DataType of the tensor keys equals to the table key_dtype |
125 | virtual Status CheckKeyTensorForRemove(const Tensor& keys); |
126 | |
127 | // Check the arguments of a find operation. Returns OK if all the following |
128 | // requirements are satisfied, otherwise it returns InvalidArgument: |
129 | // - DataType of the tensor keys equals to the table key_dtype |
130 | // - DataType of the tensor default_value equals to the table value_dtype |
131 | // - the default_value tensor has the required shape given keys and the |
132 | // tables's value shape. |
133 | Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); |
134 | |
135 | string DebugString() const override { |
136 | return strings::StrCat("A lookup table of size: " , size()); |
137 | } |
138 | |
139 | // Returns an InitializableLookupTable, a subclass of LookupInterface, if the |
140 | // current object is an InitializableLookupTable. Otherwise, returns nullptr. |
141 | virtual InitializableLookupTable* GetInitializableLookupTable() { |
142 | return nullptr; |
143 | } |
144 | |
145 | protected: |
146 | virtual ~LookupInterface() = default; |
147 | |
148 | // Makes sure that the key and value tensor DataType's match the table |
149 | // key_dtype and value_dtype. |
150 | Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); |
151 | |
152 | // Makes sure that the provided shape is consistent with the table keys shape. |
153 | Status CheckKeyShape(const TensorShape& shape); |
154 | |
155 | private: |
156 | Status CheckKeyAndValueTensorsHelper(const Tensor& keys, |
157 | const Tensor& values); |
158 | }; |
159 | |
160 | } // namespace lookup |
161 | } // namespace tensorflow |
162 | |
163 | #endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ |
164 | |