1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/lookup_util.h"
17
18#include "tensorflow/core/framework/function_handle_cache.h"
19#include "tensorflow/core/framework/lookup_interface.h"
20#include "tensorflow/core/framework/op_requires.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/framework/tensor_shape.h"
23#include "tensorflow/core/graph/graph_def_builder.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/lib/io/inputbuffer.h"
26#include "tensorflow/core/platform/refcount.h"
27
28namespace tensorflow {
29namespace lookup {
30namespace {
31
32using InitializerSerializer =
33 ::tensorflow::lookup::InitializableLookupTable::InitializerSerializer;
34
35static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
36static const int kLineNumber = -1;
37static const int kWholeLine = -2;
38
39Status GetNumLinesInTextFile(Env* env, const string& vocab_file,
40 int64_t* num_lines) {
41 std::unique_ptr<RandomAccessFile> file;
42 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
43
44 io::InputBuffer input_buffer(file.get(), kInputBufferSize);
45 string line;
46 Status s = input_buffer.ReadLine(&line);
47 int64_t next_id = 0;
48 while (s.ok()) {
49 next_id++;
50 s = input_buffer.ReadLine(&line);
51 }
52 if (!errors::IsOutOfRange(s)) {
53 return s;
54 }
55 *num_lines = next_id;
56 return OkStatus();
57}
58
59// Iterator that reads a text file. Each iteration process one line, it parses
60// the line and populates the keys and values tensors used for initialization
61// with a single key and corresponding value.
62//
63// What information of the line to populate the key or values is specified by
64// providing key_index and value_index.
65class TextFileLineIterator
66 : public InitializableLookupTable::InitTableIterator {
67 public:
68 TextFileLineIterator()
69 : valid_(false),
70 vocab_size_(-1),
71 status_(errors::FailedPrecondition("Not initialized")) {}
72
73 // Initialize iterator.
74 //
75 // Prepares the file 'filename' and sets the data types to return the keys and
76 // values tensors. It requires the indices of the tokens in the line given a
77 // delimiter to specify where to pick the data from.
78 //
79 // - Index -2 means the entire line as string.
80 // - Index -1 means the line number stored in int64.
81 // - Index >= 0 represent index (starting at zero) of the split line based on
82 // delimiter.
83 Status Init(const string& filename, int64_t vocab_size, char delimiter,
84 DataType key_dtype, int64_t key_index, DataType value_dtype,
85 int64_t value_index, int64_t offset, Env* env) {
86 filename_ = filename;
87 vocab_size_ = vocab_size;
88 delimiter_ = delimiter;
89 key_ = Tensor(key_dtype, TensorShape({}));
90 value_ = Tensor(value_dtype, TensorShape({}));
91 key_index_ = key_index;
92 value_index_ = value_index;
93 env_ = env;
94
95 status_ = env->NewRandomAccessFile(filename_, &file_);
96 if (!status_.ok()) return status_;
97
98 input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
99 valid_ = true;
100 next_id_ = 0;
101 offset_ = offset;
102 ignore_split_ = std::max(key_index_, value_index_) < 0;
103 Next();
104 return status_;
105 }
106
107 void Next() override {
108 if (!valid_) return;
109
110 string line;
111 status_ = input_buffer_->ReadLine(&line);
112 if (!status_.ok()) {
113 if (errors::IsOutOfRange(status_) && vocab_size_ != -1 &&
114 next_id_ != vocab_size_) {
115 status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
116 ": expected ", vocab_size_,
117 " but got ", next_id_);
118 }
119 valid_ = false;
120 return;
121 }
122 if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
123 LOG(WARNING) << "Truncated " << filename_ << " before its end at "
124 << vocab_size_ << " records.";
125 LOG(WARNING) << "next_id_ : " << next_id_;
126 status_ = errors::OutOfRange("Finished reading ", vocab_size_,
127 " of lines from ", filename_);
128 valid_ = false;
129 return;
130 }
131 if (line.empty()) {
132 status_ = errors::InvalidArgument("Invalid content in ", filename_,
133 ": empty line found at position ",
134 input_buffer_->Tell(), ".");
135 valid_ = false;
136 return;
137 }
138
139 std::vector<string> tokens;
140 if (!ignore_split_) {
141 tokens = str_util::Split(line, delimiter_);
142 const auto expected_size =
143 static_cast<size_t>(std::max(key_index_, value_index_) + 1);
144 if (tokens.size() < expected_size) {
145 status_ = errors::InvalidArgument(
146 "Invalid number of columns in ", filename_, " line ", next_id_,
147 " (", line, ") : expected at least ", expected_size, " got ",
148 tokens.size());
149 valid_ = false;
150 return;
151 }
152 }
153
154 status_ = SetValue(line, tokens, key_index_, &key_);
155 if (!status_.ok()) {
156 valid_ = false;
157 return;
158 }
159 status_ = SetValue(line, tokens, value_index_, &value_);
160 if (!status_.ok()) {
161 valid_ = false;
162 return;
163 }
164
165 next_id_++;
166 }
167
168 bool Valid() const override { return valid_; }
169
170 const Tensor& keys() const override { return key_; }
171
172 const Tensor& values() const override { return value_; }
173
174 Status status() const override { return status_; }
175
176 int64_t total_size() const override {
177 if (vocab_size_ == -1) {
178 int64_t new_size = -1;
179 Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
180 if (!status.ok()) {
181 LOG(WARNING) << "Unable to get line count: " << status;
182 new_size = -1;
183 }
184 *const_cast<int64_t*>(&vocab_size_) = new_size;
185 }
186 return vocab_size_;
187 }
188
189 private:
190 Tensor key_;
191 Tensor value_;
192 bool valid_; // true if the iterator points to an existing range.
193 int64_t key_index_;
194 int64_t value_index_;
195 Env* env_;
196 int64_t next_id_;
197 int64_t offset_;
198 int64_t vocab_size_;
199 string filename_;
200 char delimiter_;
201 Status status_;
202 bool ignore_split_;
203 std::unique_ptr<RandomAccessFile> file_; // must outlive input_buffer_
204 std::unique_ptr<io::InputBuffer> input_buffer_;
205
206 // Set the corresponding value from line or tokens based on 'index' into the
207 // tensor 't'. The value is transformed to the given data type 'dtype'.
208 Status SetValue(const string& line, const std::vector<string>& tokens,
209 int64_t index, Tensor* tensor) {
210 if (index == kLineNumber) {
211 tensor->flat<int64_t>()(0) = next_id_ + offset_;
212 return OkStatus();
213 }
214 const string& token = (index == kWholeLine) ? line : tokens[index];
215 const DataType& dtype = tensor->dtype();
216 switch (dtype) {
217 case DT_INT32: {
218 int32_t value;
219 if (!strings::safe_strto32(token.c_str(), &value)) {
220 valid_ = false;
221 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
222 " is not a valid int32.");
223 }
224 tensor->flat<int32>()(0) = value + offset_;
225 } break;
226 case DT_INT64: {
227 int64_t value;
228 if (!strings::safe_strto64(token.c_str(), &value)) {
229 valid_ = false;
230 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
231 " is not a valid int64.");
232 }
233 tensor->flat<int64_t>()(0) = value;
234 } break;
235 case DT_FLOAT: {
236 float value;
237 if (!strings::safe_strtof(token.c_str(), &value)) {
238 valid_ = false;
239 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
240 " is not a valid float.");
241 }
242 tensor->flat<float>()(0) = value;
243 } break;
244 case DT_DOUBLE: {
245 double value;
246 if (!strings::safe_strtod(token.c_str(), &value)) {
247 valid_ = false;
248 return errors::InvalidArgument("Field ", token, " in line ", next_id_,
249 " is not a valid double.");
250 }
251 tensor->flat<double>()(0) = value;
252 } break;
253 case DT_STRING:
254 tensor->flat<tstring>()(0) = token;
255 break;
256 default:
257 valid_ = false;
258 return errors::InvalidArgument("Data type ", DataTypeString(dtype),
259 " not supported.");
260 }
261 return OkStatus();
262 }
263
264 TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator);
265};
266
267Status GetTableHandle(StringPiece input_name, OpKernelContext* ctx,
268 string* container, string* table_handle) {
269 {
270 mutex* mu;
271 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
272 mutex_lock l(*mu);
273 Tensor tensor;
274 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
275 if (tensor.NumElements() != 2) {
276 return errors::InvalidArgument(
277 "Lookup table handle must be scalar, but had shape: ",
278 tensor.shape().DebugString());
279 }
280 auto h = tensor.flat<tstring>();
281 *container = h(0);
282 *table_handle = h(1);
283 }
284 return OkStatus();
285}
286
287} // namespace
288
289Status GetResourceLookupTable(StringPiece input_name, OpKernelContext* ctx,
290 LookupInterface** table) {
291 const Tensor* handle_tensor;
292 TF_RETURN_IF_ERROR(ctx->input(input_name, &handle_tensor));
293 const ResourceHandle& handle = handle_tensor->scalar<ResourceHandle>()();
294 return LookupResource(ctx, handle, table);
295}
296
297Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext* ctx,
298 LookupInterface** table) {
299 string container;
300 string table_handle;
301 TF_RETURN_IF_ERROR(
302 GetTableHandle(input_name, ctx, &container, &table_handle));
303 return ctx->resource_manager()->Lookup(container, table_handle, table);
304}
305
306Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx,
307 LookupInterface** table) {
308 DataType handle_dtype;
309 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
310 if (handle_dtype == DT_RESOURCE) {
311 return GetResourceLookupTable(input_name, ctx, table);
312 } else {
313 return GetReferenceLookupTable(input_name, ctx, table);
314 }
315}
316
317Status GetInitializableLookupTable(StringPiece input_name, OpKernelContext* ctx,
318 InitializableLookupTable** table) {
319 LookupInterface* lookup_table;
320 DataType handle_dtype;
321 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
322 if (handle_dtype == DT_RESOURCE) {
323 ResourceHandle handle;
324 TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
325 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
326 *table = lookup_table->GetInitializableLookupTable();
327 if (*table == nullptr) {
328 lookup_table->Unref();
329 return errors::InvalidArgument("Table ", handle.container(), " ",
330 handle.name(), " is not initializable");
331 }
332 } else {
333 string container;
334 string table_handle;
335 TF_RETURN_IF_ERROR(
336 GetTableHandle(input_name, ctx, &container, &table_handle));
337 TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
338 &lookup_table));
339 *table = lookup_table->GetInitializableLookupTable();
340 if (*table == nullptr) {
341 lookup_table->Unref();
342 return errors::InvalidArgument("Table ", container, " ", table_handle,
343 " is not initializable");
344 }
345 }
346 return OkStatus();
347}
348
349Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
350 DataType value_dtype, const string& table_name) {
351 if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) {
352 return errors::InvalidArgument(
353 "Conflicting key/value dtypes ", DataTypeString(key_dtype), "->",
354 DataTypeString(value_dtype), " with ",
355 DataTypeString(table.key_dtype()), "-",
356 DataTypeString(table.value_dtype()), " for table ", table_name);
357 }
358 return OkStatus();
359}
360
361// Helper function to initialize an InitializableLookupTable from a text file.
362Status InitializeTableFromTextFile(const string& filename, int64_t vocab_size,
363 char delimiter, int32_t key_index,
364 int32_t value_index, int64_t offset,
365 Env* env, InitializableLookupTable* table) {
366 return InitializeTableFromTextFile(filename, vocab_size, delimiter, key_index,
367 value_index, offset, env,
368 /*serializer=*/nullptr, table);
369}
370
371Status InitializeTableFromTextFile(
372 const string& filename, int64_t vocab_size, char delimiter,
373 int32_t key_index, int32_t value_index, int64_t offset, Env* env,
374 std::unique_ptr<InitializableLookupTable::InitializerSerializer> serializer,
375 InitializableLookupTable* table) {
376 if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
377 return errors::InvalidArgument(
378 "Key index for line number requires table key dtype of int64, got ",
379 DataTypeString(table->key_dtype()));
380 }
381 const DataType& key_dtype = table->key_dtype();
382 const DataType& value_dtype = table->value_dtype();
383 if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) &&
384 key_dtype != DT_STRING) {
385 return errors::InvalidArgument(
386 "Key index for whole line requires string or integer table key, got ",
387 DataTypeString(table->key_dtype()));
388 }
389 if (value_index == kLineNumber && value_dtype != DT_INT64) {
390 return errors::InvalidArgument(
391 "Value index for line number requires table value dtype of int64, got ",
392 DataTypeString(table->value_dtype()));
393 }
394 if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) &&
395 value_dtype != DT_STRING) {
396 return errors::InvalidArgument(
397 "Value index for whole line requires table value dtype of integer or "
398 "string, got ",
399 DataTypeString(table->value_dtype()));
400 }
401
402 TextFileLineIterator iter;
403 TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
404 key_index, value_dtype, value_index, offset,
405 env));
406 // For initialization from files, ignore if the table is already
407 // initialized. The table shared name should contain the filename to
408 // avoid trying to initialize the same table from the same file at the same
409 // time.
410 Status s = table->Initialize(iter, std::move(serializer));
411 if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
412 LOG(INFO) << "Table trying to initialize from file " << filename
413 << " is already initialized.";
414 return OkStatus();
415 }
416 return s;
417}
418
419} // namespace lookup
420} // namespace tensorflow
421