1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/lookup_util.h" |
17 | |
18 | #include "tensorflow/core/framework/function_handle_cache.h" |
19 | #include "tensorflow/core/framework/lookup_interface.h" |
20 | #include "tensorflow/core/framework/op_requires.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/graph/graph_def_builder.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/io/inputbuffer.h" |
26 | #include "tensorflow/core/platform/refcount.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace lookup { |
30 | namespace { |
31 | |
32 | using InitializerSerializer = |
33 | ::tensorflow::lookup::InitializableLookupTable::InitializerSerializer; |
34 | |
// Read-buffer size used when scanning vocabulary files (1 MiB).
// `static` is redundant inside an unnamed namespace; `constexpr` guarantees
// compile-time initialization.
constexpr int kInputBufferSize = 1 * 1024 * 1024; /* bytes */

// Sentinel column indices understood by TextFileLineIterator and
// InitializeTableFromTextFile. Real column indices are >= 0.
constexpr int kLineNumber = -1;  // Use the zero-based line number (+ offset).
constexpr int kWholeLine = -2;   // Use the entire line.
38 | |
39 | Status GetNumLinesInTextFile(Env* env, const string& vocab_file, |
40 | int64_t* num_lines) { |
41 | std::unique_ptr<RandomAccessFile> file; |
42 | TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); |
43 | |
44 | io::InputBuffer input_buffer(file.get(), kInputBufferSize); |
45 | string line; |
46 | Status s = input_buffer.ReadLine(&line); |
47 | int64_t next_id = 0; |
48 | while (s.ok()) { |
49 | next_id++; |
50 | s = input_buffer.ReadLine(&line); |
51 | } |
52 | if (!errors::IsOutOfRange(s)) { |
53 | return s; |
54 | } |
55 | *num_lines = next_id; |
56 | return OkStatus(); |
57 | } |
58 | |
59 | // Iterator that reads a text file. Each iteration process one line, it parses |
60 | // the line and populates the keys and values tensors used for initialization |
61 | // with a single key and corresponding value. |
62 | // |
63 | // What information of the line to populate the key or values is specified by |
64 | // providing key_index and value_index. |
65 | class TextFileLineIterator |
66 | : public InitializableLookupTable::InitTableIterator { |
67 | public: |
68 | TextFileLineIterator() |
69 | : valid_(false), |
70 | vocab_size_(-1), |
71 | status_(errors::FailedPrecondition("Not initialized" )) {} |
72 | |
73 | // Initialize iterator. |
74 | // |
75 | // Prepares the file 'filename' and sets the data types to return the keys and |
76 | // values tensors. It requires the indices of the tokens in the line given a |
77 | // delimiter to specify where to pick the data from. |
78 | // |
79 | // - Index -2 means the entire line as string. |
80 | // - Index -1 means the line number stored in int64. |
81 | // - Index >= 0 represent index (starting at zero) of the split line based on |
82 | // delimiter. |
83 | Status Init(const string& filename, int64_t vocab_size, char delimiter, |
84 | DataType key_dtype, int64_t key_index, DataType value_dtype, |
85 | int64_t value_index, int64_t offset, Env* env) { |
86 | filename_ = filename; |
87 | vocab_size_ = vocab_size; |
88 | delimiter_ = delimiter; |
89 | key_ = Tensor(key_dtype, TensorShape({})); |
90 | value_ = Tensor(value_dtype, TensorShape({})); |
91 | key_index_ = key_index; |
92 | value_index_ = value_index; |
93 | env_ = env; |
94 | |
95 | status_ = env->NewRandomAccessFile(filename_, &file_); |
96 | if (!status_.ok()) return status_; |
97 | |
98 | input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize)); |
99 | valid_ = true; |
100 | next_id_ = 0; |
101 | offset_ = offset; |
102 | ignore_split_ = std::max(key_index_, value_index_) < 0; |
103 | Next(); |
104 | return status_; |
105 | } |
106 | |
107 | void Next() override { |
108 | if (!valid_) return; |
109 | |
110 | string line; |
111 | status_ = input_buffer_->ReadLine(&line); |
112 | if (!status_.ok()) { |
113 | if (errors::IsOutOfRange(status_) && vocab_size_ != -1 && |
114 | next_id_ != vocab_size_) { |
115 | status_ = errors::InvalidArgument("Invalid vocab_size in " , filename_, |
116 | ": expected " , vocab_size_, |
117 | " but got " , next_id_); |
118 | } |
119 | valid_ = false; |
120 | return; |
121 | } |
122 | if (vocab_size_ != -1 && next_id_ >= vocab_size_) { |
123 | LOG(WARNING) << "Truncated " << filename_ << " before its end at " |
124 | << vocab_size_ << " records." ; |
125 | LOG(WARNING) << "next_id_ : " << next_id_; |
126 | status_ = errors::OutOfRange("Finished reading " , vocab_size_, |
127 | " of lines from " , filename_); |
128 | valid_ = false; |
129 | return; |
130 | } |
131 | if (line.empty()) { |
132 | status_ = errors::InvalidArgument("Invalid content in " , filename_, |
133 | ": empty line found at position " , |
134 | input_buffer_->Tell(), "." ); |
135 | valid_ = false; |
136 | return; |
137 | } |
138 | |
139 | std::vector<string> tokens; |
140 | if (!ignore_split_) { |
141 | tokens = str_util::Split(line, delimiter_); |
142 | const auto expected_size = |
143 | static_cast<size_t>(std::max(key_index_, value_index_) + 1); |
144 | if (tokens.size() < expected_size) { |
145 | status_ = errors::InvalidArgument( |
146 | "Invalid number of columns in " , filename_, " line " , next_id_, |
147 | " (" , line, ") : expected at least " , expected_size, " got " , |
148 | tokens.size()); |
149 | valid_ = false; |
150 | return; |
151 | } |
152 | } |
153 | |
154 | status_ = SetValue(line, tokens, key_index_, &key_); |
155 | if (!status_.ok()) { |
156 | valid_ = false; |
157 | return; |
158 | } |
159 | status_ = SetValue(line, tokens, value_index_, &value_); |
160 | if (!status_.ok()) { |
161 | valid_ = false; |
162 | return; |
163 | } |
164 | |
165 | next_id_++; |
166 | } |
167 | |
168 | bool Valid() const override { return valid_; } |
169 | |
170 | const Tensor& keys() const override { return key_; } |
171 | |
172 | const Tensor& values() const override { return value_; } |
173 | |
174 | Status status() const override { return status_; } |
175 | |
176 | int64_t total_size() const override { |
177 | if (vocab_size_ == -1) { |
178 | int64_t new_size = -1; |
179 | Status status = GetNumLinesInTextFile(env_, filename_, &new_size); |
180 | if (!status.ok()) { |
181 | LOG(WARNING) << "Unable to get line count: " << status; |
182 | new_size = -1; |
183 | } |
184 | *const_cast<int64_t*>(&vocab_size_) = new_size; |
185 | } |
186 | return vocab_size_; |
187 | } |
188 | |
189 | private: |
190 | Tensor key_; |
191 | Tensor value_; |
192 | bool valid_; // true if the iterator points to an existing range. |
193 | int64_t key_index_; |
194 | int64_t value_index_; |
195 | Env* env_; |
196 | int64_t next_id_; |
197 | int64_t offset_; |
198 | int64_t vocab_size_; |
199 | string filename_; |
200 | char delimiter_; |
201 | Status status_; |
202 | bool ignore_split_; |
203 | std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_ |
204 | std::unique_ptr<io::InputBuffer> input_buffer_; |
205 | |
206 | // Set the corresponding value from line or tokens based on 'index' into the |
207 | // tensor 't'. The value is transformed to the given data type 'dtype'. |
208 | Status SetValue(const string& line, const std::vector<string>& tokens, |
209 | int64_t index, Tensor* tensor) { |
210 | if (index == kLineNumber) { |
211 | tensor->flat<int64_t>()(0) = next_id_ + offset_; |
212 | return OkStatus(); |
213 | } |
214 | const string& token = (index == kWholeLine) ? line : tokens[index]; |
215 | const DataType& dtype = tensor->dtype(); |
216 | switch (dtype) { |
217 | case DT_INT32: { |
218 | int32_t value; |
219 | if (!strings::safe_strto32(token.c_str(), &value)) { |
220 | valid_ = false; |
221 | return errors::InvalidArgument("Field " , token, " in line " , next_id_, |
222 | " is not a valid int32." ); |
223 | } |
224 | tensor->flat<int32>()(0) = value + offset_; |
225 | } break; |
226 | case DT_INT64: { |
227 | int64_t value; |
228 | if (!strings::safe_strto64(token.c_str(), &value)) { |
229 | valid_ = false; |
230 | return errors::InvalidArgument("Field " , token, " in line " , next_id_, |
231 | " is not a valid int64." ); |
232 | } |
233 | tensor->flat<int64_t>()(0) = value; |
234 | } break; |
235 | case DT_FLOAT: { |
236 | float value; |
237 | if (!strings::safe_strtof(token.c_str(), &value)) { |
238 | valid_ = false; |
239 | return errors::InvalidArgument("Field " , token, " in line " , next_id_, |
240 | " is not a valid float." ); |
241 | } |
242 | tensor->flat<float>()(0) = value; |
243 | } break; |
244 | case DT_DOUBLE: { |
245 | double value; |
246 | if (!strings::safe_strtod(token.c_str(), &value)) { |
247 | valid_ = false; |
248 | return errors::InvalidArgument("Field " , token, " in line " , next_id_, |
249 | " is not a valid double." ); |
250 | } |
251 | tensor->flat<double>()(0) = value; |
252 | } break; |
253 | case DT_STRING: |
254 | tensor->flat<tstring>()(0) = token; |
255 | break; |
256 | default: |
257 | valid_ = false; |
258 | return errors::InvalidArgument("Data type " , DataTypeString(dtype), |
259 | " not supported." ); |
260 | } |
261 | return OkStatus(); |
262 | } |
263 | |
264 | TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator); |
265 | }; |
266 | |
267 | Status GetTableHandle(StringPiece input_name, OpKernelContext* ctx, |
268 | string* container, string* table_handle) { |
269 | { |
270 | mutex* mu; |
271 | TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); |
272 | mutex_lock l(*mu); |
273 | Tensor tensor; |
274 | TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); |
275 | if (tensor.NumElements() != 2) { |
276 | return errors::InvalidArgument( |
277 | "Lookup table handle must be scalar, but had shape: " , |
278 | tensor.shape().DebugString()); |
279 | } |
280 | auto h = tensor.flat<tstring>(); |
281 | *container = h(0); |
282 | *table_handle = h(1); |
283 | } |
284 | return OkStatus(); |
285 | } |
286 | |
287 | } // namespace |
288 | |
289 | Status GetResourceLookupTable(StringPiece input_name, OpKernelContext* ctx, |
290 | LookupInterface** table) { |
291 | const Tensor* handle_tensor; |
292 | TF_RETURN_IF_ERROR(ctx->input(input_name, &handle_tensor)); |
293 | const ResourceHandle& handle = handle_tensor->scalar<ResourceHandle>()(); |
294 | return LookupResource(ctx, handle, table); |
295 | } |
296 | |
297 | Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext* ctx, |
298 | LookupInterface** table) { |
299 | string container; |
300 | string table_handle; |
301 | TF_RETURN_IF_ERROR( |
302 | GetTableHandle(input_name, ctx, &container, &table_handle)); |
303 | return ctx->resource_manager()->Lookup(container, table_handle, table); |
304 | } |
305 | |
306 | Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, |
307 | LookupInterface** table) { |
308 | DataType handle_dtype; |
309 | TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); |
310 | if (handle_dtype == DT_RESOURCE) { |
311 | return GetResourceLookupTable(input_name, ctx, table); |
312 | } else { |
313 | return GetReferenceLookupTable(input_name, ctx, table); |
314 | } |
315 | } |
316 | |
317 | Status GetInitializableLookupTable(StringPiece input_name, OpKernelContext* ctx, |
318 | InitializableLookupTable** table) { |
319 | LookupInterface* lookup_table; |
320 | DataType handle_dtype; |
321 | TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); |
322 | if (handle_dtype == DT_RESOURCE) { |
323 | ResourceHandle handle; |
324 | TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); |
325 | TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table)); |
326 | *table = lookup_table->GetInitializableLookupTable(); |
327 | if (*table == nullptr) { |
328 | lookup_table->Unref(); |
329 | return errors::InvalidArgument("Table " , handle.container(), " " , |
330 | handle.name(), " is not initializable" ); |
331 | } |
332 | } else { |
333 | string container; |
334 | string table_handle; |
335 | TF_RETURN_IF_ERROR( |
336 | GetTableHandle(input_name, ctx, &container, &table_handle)); |
337 | TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle, |
338 | &lookup_table)); |
339 | *table = lookup_table->GetInitializableLookupTable(); |
340 | if (*table == nullptr) { |
341 | lookup_table->Unref(); |
342 | return errors::InvalidArgument("Table " , container, " " , table_handle, |
343 | " is not initializable" ); |
344 | } |
345 | } |
346 | return OkStatus(); |
347 | } |
348 | |
349 | Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, |
350 | DataType value_dtype, const string& table_name) { |
351 | if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) { |
352 | return errors::InvalidArgument( |
353 | "Conflicting key/value dtypes " , DataTypeString(key_dtype), "->" , |
354 | DataTypeString(value_dtype), " with " , |
355 | DataTypeString(table.key_dtype()), "-" , |
356 | DataTypeString(table.value_dtype()), " for table " , table_name); |
357 | } |
358 | return OkStatus(); |
359 | } |
360 | |
361 | // Helper function to initialize an InitializableLookupTable from a text file. |
362 | Status InitializeTableFromTextFile(const string& filename, int64_t vocab_size, |
363 | char delimiter, int32_t key_index, |
364 | int32_t value_index, int64_t offset, |
365 | Env* env, InitializableLookupTable* table) { |
366 | return InitializeTableFromTextFile(filename, vocab_size, delimiter, key_index, |
367 | value_index, offset, env, |
368 | /*serializer=*/nullptr, table); |
369 | } |
370 | |
371 | Status InitializeTableFromTextFile( |
372 | const string& filename, int64_t vocab_size, char delimiter, |
373 | int32_t key_index, int32_t value_index, int64_t offset, Env* env, |
374 | std::unique_ptr<InitializableLookupTable::InitializerSerializer> serializer, |
375 | InitializableLookupTable* table) { |
376 | if (key_index == kLineNumber && table->key_dtype() != DT_INT64) { |
377 | return errors::InvalidArgument( |
378 | "Key index for line number requires table key dtype of int64, got " , |
379 | DataTypeString(table->key_dtype())); |
380 | } |
381 | const DataType& key_dtype = table->key_dtype(); |
382 | const DataType& value_dtype = table->value_dtype(); |
383 | if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) && |
384 | key_dtype != DT_STRING) { |
385 | return errors::InvalidArgument( |
386 | "Key index for whole line requires string or integer table key, got " , |
387 | DataTypeString(table->key_dtype())); |
388 | } |
389 | if (value_index == kLineNumber && value_dtype != DT_INT64) { |
390 | return errors::InvalidArgument( |
391 | "Value index for line number requires table value dtype of int64, got " , |
392 | DataTypeString(table->value_dtype())); |
393 | } |
394 | if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) && |
395 | value_dtype != DT_STRING) { |
396 | return errors::InvalidArgument( |
397 | "Value index for whole line requires table value dtype of integer or " |
398 | "string, got " , |
399 | DataTypeString(table->value_dtype())); |
400 | } |
401 | |
402 | TextFileLineIterator iter; |
403 | TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype, |
404 | key_index, value_dtype, value_index, offset, |
405 | env)); |
406 | // For initialization from files, ignore if the table is already |
407 | // initialized. The table shared name should contain the filename to |
408 | // avoid trying to initialize the same table from the same file at the same |
409 | // time. |
410 | Status s = table->Initialize(iter, std::move(serializer)); |
411 | if (errors::IsFailedPrecondition(s) && table->is_initialized()) { |
412 | LOG(INFO) << "Table trying to initialize from file " << filename |
413 | << " is already initialized." ; |
414 | return OkStatus(); |
415 | } |
416 | return s; |
417 | } |
418 | |
419 | } // namespace lookup |
420 | } // namespace tensorflow |
421 | |