1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
17#define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
18
19#include "tensorflow/core/framework/resource_mgr.h"
20#include "tensorflow/core/framework/tensor.h"
21#include "tensorflow/core/lib/core/status.h"
22
23namespace tensorflow {
24
// Forward declaration: this header only refers to OpKernelContext through
// pointers (the `ctx` parameters below), so its full definition is not needed.
class OpKernelContext;
26
27namespace lookup {
28
29// Forward declaration so we can define GetInitializableLookupTable() in
30// LookupInterface.
31class InitializableLookupTable;
32
33// Lookup interface for batch lookups used by table lookup ops.
34class LookupInterface : public ResourceBase {
35 public:
36 // Performs batch lookups, for every element in the key tensor, Find returns
37 // the corresponding value into the values tensor.
38 // If an element is not present in the table, the given default value is used.
39
40 // For tables that require initialization, Find is available once the table
41 // is marked as initialized.
42
43 // Returns the following statuses:
44 // - OK: when the find finishes successfully.
45 // - FailedPrecondition: if the table is not initialized.
46 // - InvalidArgument: if any of the preconditions on the lookup key or value
47 // fails.
48 // - In addition, other implementations may provide another non-OK status
49 // specific to their failure modes.
50 virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
51 const Tensor& default_value) = 0;
52
53 // Inserts elements into the table. Each element of the key tensor is
54 // associated with the corresponding element in the value tensor.
55 // This method is only implemented in mutable tables that can be updated over
56 // the execution of the graph. It returns Status::NotImplemented for read-only
57 // tables that are initialized once before they can be looked up.
58
59 // Returns the following statuses:
60 // - OK: when the insert finishes successfully.
61 // - InvalidArgument: if any of the preconditions on the lookup key or value
62 // fails.
63 // - Unimplemented: if the table does not support insertions.
64 virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
65 const Tensor& values) = 0;
66
67 // Removes elements from the table.
68 // This method is only implemented in mutable tables that can be updated over
69 // the execution of the graph. It returns Status::NotImplemented for read-only
70 // tables that are initialized once before they can be looked up.
71
72 // Returns the following statuses:
73 // - OK: when the remove finishes successfully.
74 // - InvalidArgument: if any of the preconditions on the lookup key fails.
75 // - Unimplemented: if the table does not support removals.
76 virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0;
77
78 // Returns the number of elements in the table.
79 virtual size_t size() const = 0;
80
81 // Exports the values of the table to two tensors named keys and values.
82 // Note that the shape of the tensors is completely up to the implementation
83 // of the table and can be different than the tensors used for the Insert
84 // function above.
85 virtual Status ExportValues(OpKernelContext* ctx) = 0;
86
87 // Imports previously exported keys and values.
88 // As mentioned above, the shape of the keys and values tensors are determined
89 // by the ExportValues function above and can be different than for the
90 // Insert function.
91 virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
92 const Tensor& values) = 0;
93
94 // Returns the data type of the key.
95 virtual DataType key_dtype() const = 0;
96
97 // Returns the data type of the value.
98 virtual DataType value_dtype() const = 0;
99
100 // Returns the shape of a key in the table.
101 virtual TensorShape key_shape() const = 0;
102
103 // Returns the shape of a value in the table.
104 virtual TensorShape value_shape() const = 0;
105
106 // Check format of the key and value tensors for the Insert function.
107 // Returns OK if all the following requirements are satisfied, otherwise it
108 // returns InvalidArgument:
109 // - DataType of the tensor keys equals to the table key_dtype
110 // - DataType of the tensor values equals to the table value_dtype
111 // - the values tensor has the required shape given keys and the tables's
112 // value shape.
113 virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
114 const Tensor& values);
115
116 // Similar to the function above but instead checks eligibility for the Import
117 // function.
118 virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
119 const Tensor& values);
120
121 // Check format of the key tensor for the Remove function.
122 // Returns OK if all the following requirements are satisfied, otherwise it
123 // returns InvalidArgument:
124 // - DataType of the tensor keys equals to the table key_dtype
125 virtual Status CheckKeyTensorForRemove(const Tensor& keys);
126
127 // Check the arguments of a find operation. Returns OK if all the following
128 // requirements are satisfied, otherwise it returns InvalidArgument:
129 // - DataType of the tensor keys equals to the table key_dtype
130 // - DataType of the tensor default_value equals to the table value_dtype
131 // - the default_value tensor has the required shape given keys and the
132 // tables's value shape.
133 Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
134
135 string DebugString() const override {
136 return strings::StrCat("A lookup table of size: ", size());
137 }
138
139 // Returns an InitializableLookupTable, a subclass of LookupInterface, if the
140 // current object is an InitializableLookupTable. Otherwise, returns nullptr.
141 virtual InitializableLookupTable* GetInitializableLookupTable() {
142 return nullptr;
143 }
144
145 protected:
146 virtual ~LookupInterface() = default;
147
148 // Makes sure that the key and value tensor DataType's match the table
149 // key_dtype and value_dtype.
150 Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);
151
152 // Makes sure that the provided shape is consistent with the table keys shape.
153 Status CheckKeyShape(const TensorShape& shape);
154
155 private:
156 Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
157 const Tensor& values);
158};
159
160} // namespace lookup
161} // namespace tensorflow
162
163#endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
164