1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #define EIGEN_USE_THREADS |
16 | |
17 | #include "tensorflow/core/kernels/lookup_table_init_op.h" |
18 | |
19 | #include <algorithm> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/register_types.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/graph/graph_def_builder.h" |
30 | #include "tensorflow/core/kernels/lookup_util.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/lib/core/status.h" |
33 | #include "tensorflow/core/lib/io/inputbuffer.h" |
34 | #include "tensorflow/core/lib/strings/numbers.h" |
35 | #include "tensorflow/core/lib/strings/str_util.h" |
36 | #include "tensorflow/core/platform/macros.h" |
37 | |
38 | namespace tensorflow { |
39 | |
// Shorthand for the table's initializer-serializer type. In this file it is
// constructed from a callable taking (GraphDefBuilder*, Node* table, Node** out)
// that re-emits the table's initialization into a graph.
using InitializerSerializer =
    lookup::InitializableLookupTable::InitializerSerializer;
42 | |
43 | // Kernel to initialize a look table given a key and value tensors. |
44 | // After this operation, the table becomes read-only. |
45 | class InitializeTableOp : public OpKernel { |
46 | public: |
47 | explicit InitializeTableOp(OpKernelConstruction* context) |
48 | : OpKernel(context) {} |
49 | |
50 | void Compute(OpKernelContext* ctx) override { |
51 | mutex_lock l(mu_); |
52 | lookup::InitializableLookupTable* table; |
53 | OP_REQUIRES_OK(ctx, |
54 | GetInitializableLookupTable("table_handle" , ctx, &table)); |
55 | core::ScopedUnref unref_me(table); |
56 | |
57 | DataType expected_input_0 = |
58 | (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; |
59 | DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), |
60 | table->value_dtype()}; |
61 | DataTypeVector expected_outputs = {}; |
62 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); |
63 | |
64 | const Tensor& keys = ctx->input(1); |
65 | OP_REQUIRES( |
66 | ctx, TensorShapeUtils::IsVector(keys.shape()), |
67 | errors::InvalidArgument("Keys must be a vector, but received shape" , |
68 | keys.shape().DebugString())); |
69 | |
70 | const Tensor& values = ctx->input(2); |
71 | OP_REQUIRES( |
72 | ctx, TensorShapeUtils::IsVector(values.shape()), |
73 | errors::InvalidArgument("Values must be a vector, but received shape" , |
74 | values.shape().DebugString())); |
75 | |
76 | OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(), |
77 | errors::InvalidArgument( |
78 | "Keys and values must have the same size " , |
79 | keys.NumElements(), " vs " , values.NumElements())); |
80 | |
81 | int memory_used_before = 0; |
82 | if (ctx->track_allocations()) { |
83 | memory_used_before = table->MemoryUsed(); |
84 | } |
85 | OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); |
86 | if (ctx->track_allocations()) { |
87 | ctx->record_persistent_memory_allocation(table->MemoryUsed() - |
88 | memory_used_before); |
89 | } |
90 | } |
91 | |
92 | private: |
93 | mutex mu_; |
94 | }; |
95 | |
96 | REGISTER_KERNEL_BUILDER(Name("InitializeTable" ).Device(DEVICE_CPU), |
97 | InitializeTableOp); |
98 | REGISTER_KERNEL_BUILDER(Name("InitializeTableV2" ).Device(DEVICE_CPU), |
99 | InitializeTableOp); |
100 | |
101 | // Kernel to initialize a lookup table from a text file. |
102 | // |
103 | // After this operation, the table becomes read-only. |
104 | class InitializeTableFromTextFileOp : public OpKernel { |
105 | public: |
106 | explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx) |
107 | : OpKernel(ctx) { |
108 | OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size" , &vocab_size_)); |
109 | OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index" , &key_index_)); |
110 | OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index" , &value_index_)); |
111 | if (ctx->HasAttr("offset" )) { |
112 | OP_REQUIRES_OK(ctx, ctx->GetAttr("offset" , &offset_)); |
113 | } |
114 | string delimiter; |
115 | OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter" , &delimiter)); |
116 | OP_REQUIRES(ctx, delimiter.size() == 1, |
117 | errors::InvalidArgument("delimiter should be only 1 char" )); |
118 | delimiter_ = delimiter[0]; |
119 | } |
120 | |
121 | void Compute(OpKernelContext* ctx) override { |
122 | mutex_lock l(mu_); |
123 | lookup::InitializableLookupTable* table; |
124 | OP_REQUIRES_OK(ctx, |
125 | GetInitializableLookupTable("table_handle" , ctx, &table)); |
126 | core::ScopedUnref unref_me(table); |
127 | |
128 | DataType expected_input_0 = |
129 | (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; |
130 | DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; |
131 | DataTypeVector expected_outputs = {}; |
132 | OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); |
133 | |
134 | const Tensor& vocab_filename_tensor = ctx->input(1); |
135 | OP_REQUIRES( |
136 | ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()), |
137 | errors::InvalidArgument("filename should be a single string, but got " , |
138 | vocab_filename_tensor.shape().DebugString())); |
139 | |
140 | const string& vocab_filename = vocab_filename_tensor.scalar<tstring>()(); |
141 | OP_REQUIRES(ctx, !vocab_filename.empty(), |
142 | errors::InvalidArgument("filename cannot be empty." )); |
143 | |
144 | int64_t memory_used_before = 0; |
145 | if (ctx->track_allocations()) { |
146 | memory_used_before = table->MemoryUsed(); |
147 | } |
148 | OP_REQUIRES_OK( |
149 | ctx, lookup::InitializeTableFromTextFile( |
150 | vocab_filename, vocab_size_, delimiter_, key_index_, |
151 | value_index_, offset_, ctx->env(), |
152 | MakeInitializerSerializer(vocab_filename_tensor), table)); |
153 | if (ctx->track_allocations()) { |
154 | ctx->record_persistent_memory_allocation(table->MemoryUsed() - |
155 | memory_used_before); |
156 | } |
157 | } |
158 | |
159 | private: |
160 | std::unique_ptr<InitializerSerializer> MakeInitializerSerializer( |
161 | Tensor vocab_filename) { |
162 | return std::make_unique<InitializerSerializer>( |
163 | [vocab_filename, vocab_size = vocab_size_, delimiter = delimiter_, |
164 | key_index = key_index_, value_index = value_index_, |
165 | offset = offset_](GraphDefBuilder* builder, Node* table, Node** out) { |
166 | Node* vocab_filename_node = ops::SourceOp( |
167 | "Const" , builder->opts() |
168 | .WithAttr("dtype" , vocab_filename.dtype()) |
169 | .WithAttr("value" , vocab_filename)); |
170 | std::string delimiter_string(1, delimiter); |
171 | Node* import_table = ops::BinaryOp( |
172 | "InitializeTableFromTextFileV2" , table, vocab_filename_node, |
173 | builder->opts() |
174 | .WithAttr("vocab_size" , vocab_size) |
175 | .WithAttr("key_index" , key_index) |
176 | .WithAttr("value_index" , value_index) |
177 | .WithAttr("offset" , offset) |
178 | .WithAttr("delimiter" , delimiter_string)); |
179 | *out = ops::UnaryOp("Identity" , table, |
180 | builder->opts().WithControlInput(import_table)); |
181 | return OkStatus(); |
182 | }); |
183 | } |
184 | |
185 | mutex mu_; |
186 | int64_t vocab_size_; |
187 | char delimiter_; |
188 | int64_t key_index_; |
189 | int64_t value_index_; |
190 | int64_t offset_ = 0; |
191 | |
192 | TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp); |
193 | }; |
194 | |
195 | REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile" ).Device(DEVICE_CPU), |
196 | InitializeTableFromTextFileOp); |
197 | REGISTER_KERNEL_BUILDER( |
198 | Name("InitializeTableFromTextFileV2" ).Device(DEVICE_CPU), |
199 | InitializeTableFromTextFileOp); |
200 | } // namespace tensorflow |
201 | |