1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <string>
17
18#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19#include "tensorflow/core/framework/kernel_def_builder.h"
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/framework/types.h"
23#include "tensorflow/core/kernels/lookup_table_init_op.h"
24#include "tensorflow/core/kernels/lookup_table_op.h"
25#include "tensorflow/core/lib/core/errors.h"
26#include "tensorflow/core/lib/core/status.h"
27
namespace tensorflow {
namespace {
// lookup::InitializeTableFromTextFile requires a delimiter even though we use
// the entire line for vocabularies (the key/value indices -1 and -2 select
// the line number and the whole line, so the delimiter is never consulted).
constexpr char kUnusedLookupDelim = '\t';
}  // namespace
34
35// This Op generates a vocab remapping Tensor from an old and new vocabulary
36// file that maps new ID's to old ID's.
37class GenerateVocabRemappingOp : public OpKernel {
38 public:
39 explicit GenerateVocabRemappingOp(OpKernelConstruction* context)
40 : OpKernel(context) {
41 OP_REQUIRES_OK(context,
42 context->GetAttr("new_vocab_offset", &new_vocab_offset_));
43 OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_));
44 OP_REQUIRES_OK(context,
45 context->GetAttr("old_vocab_size", &old_vocab_size_));
46 }
47
48 void Compute(OpKernelContext* context) override {
49 const Tensor* new_vocab_file_tensor;
50 OP_REQUIRES_OK(context,
51 context->input("new_vocab_file", &new_vocab_file_tensor));
52 OP_REQUIRES(context,
53 TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()),
54 errors::InvalidArgument(
55 "new_vocab_file should be a single string, but got ",
56 new_vocab_file_tensor->shape().DebugString()));
57
58 // Build a new ID->token lookup table.
59 const string& new_vocab_filename =
60 new_vocab_file_tensor->scalar<tstring>()();
61 OP_REQUIRES(context, !new_vocab_filename.empty(),
62 errors::InvalidArgument("new vocab filename cannot be empty."));
63 lookup::HashTable<int64_t, tstring>* new_vocab_table =
64 new lookup::HashTable<int64_t, tstring>(context, this);
65 core::ScopedUnref unref_new(new_vocab_table);
66 // Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
67 // total elements in file. This is different from num_new_vocab_, which
68 // accounts for partitioning.
69 OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
70 new_vocab_filename,
71 -1, // vocab_size
72 kUnusedLookupDelim,
73 -1, // key_index, use the line number.
74 -2, // value_index, use the whole line/token.
75 0, // No offset.
76 context->env(), new_vocab_table));
77 OP_REQUIRES(context,
78 new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
79 errors::InvalidArgument("lookup table size must be larger than "
80 "last new vocab entry's line"));
81
82 const Tensor* old_vocab_file_tensor;
83 OP_REQUIRES_OK(context,
84 context->input("old_vocab_file", &old_vocab_file_tensor));
85 OP_REQUIRES(context,
86 TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()),
87 errors::InvalidArgument(
88 "old_vocab_file should be a single string, but got ",
89 old_vocab_file_tensor->shape().DebugString()));
90 // Build a token->old ID lookup table.
91 const string& old_vocab_filename =
92 old_vocab_file_tensor->scalar<tstring>()();
93 OP_REQUIRES(context, !old_vocab_filename.empty(),
94 errors::InvalidArgument("new vocab filename cannot be empty."));
95 lookup::HashTable<tstring, int64_t>* old_vocab_table =
96 new lookup::HashTable<tstring, int64_t>(context, this);
97 core::ScopedUnref unref_old(old_vocab_table);
98 // Note: If old_vocab_size_ is -1 (unknown), we retrieve all elements in
99 // file (see TextFileLineIterator).
100 OP_REQUIRES_OK(context,
101 lookup::InitializeTableFromTextFile(
102 old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
103 -2, // key_index, use the whole line/token.
104 -1, // value_index, use the line number.
105 0, // No offset.
106 context->env(), old_vocab_table));
107
108 // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
109 // new_vocab_offset + num_new_vocab_]
110 // The double look-up requires a few temporary Tensors.
111 Tensor new_ids;
112 OP_REQUIRES_OK(
113 context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
114 &new_ids));
115 auto new_ids_vec = new_ids.vec<int64_t>();
116 // Note that we should always be able to find tokens for all new ID's, given
117 // that the lookup table is constructed with the vocabulary file itself
118 // (see the check on offset and table size post-initialization).
119 Tensor default_token;
120 OP_REQUIRES_OK(
121 context, context->allocate_temp(
122 DT_STRING, TensorShape({num_new_vocab_}), &default_token));
123 auto default_token_vec = default_token.vec<tstring>();
124 default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */);
125
126 Tensor default_id;
127 OP_REQUIRES_OK(
128 context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
129 &default_id));
130 auto default_id_vec = default_id.vec<int64_t>();
131 default_id_vec.setConstant(-1 /* NOT_FOUND_ID */);
132
133 for (int i = 0; i < num_new_vocab_; ++i) {
134 new_ids_vec(i) = static_cast<int64_t>(i + new_vocab_offset_);
135 }
136 Tensor tokens;
137 OP_REQUIRES_OK(context,
138 context->allocate_temp(
139 DT_STRING, TensorShape({num_new_vocab_}), &tokens));
140 Tensor* remapping;
141 OP_REQUIRES_OK(context,
142 context->allocate_output(
143 "remapping", TensorShape({num_new_vocab_}), &remapping));
144 // In the corner case where num_new_vocab_ is 0 (we are dealing with an
145 // OOV-only partition), we should not do this lookup.
146 if (num_new_vocab_ != 0) {
147 OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens,
148 default_token));
149 OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping,
150 default_id));
151 }
152 // Iterate through remapping to calculate num_present.
153 const auto remapping_vec = remapping->vec<int64_t>();
154 int num_present = 0;
155 for (int i = 0; i < num_new_vocab_; ++i) {
156 if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) {
157 ++num_present;
158 }
159 }
160 Tensor* num_present_t;
161 OP_REQUIRES_OK(context,
162 context->allocate_output("num_present", TensorShape({}),
163 &num_present_t));
164 num_present_t->scalar<int>()() = num_present;
165 }
166
167 private:
168 int new_vocab_offset_;
169 int num_new_vocab_;
170 int old_vocab_size_;
171};
172
// Register the kernel; only a CPU implementation is provided.
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
                        GenerateVocabRemappingOp);
175
176} // namespace tensorflow
177