1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
19 | #include "tensorflow/core/framework/kernel_def_builder.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/kernels/lookup_table_init_op.h" |
24 | #include "tensorflow/core/kernels/lookup_table_op.h" |
25 | #include "tensorflow/core/lib/core/errors.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | |
28 | namespace tensorflow { |
namespace {
// lookup::InitializeTableFromTextFile requires a delimiter even though we use
// the entire line for vocabularies. Both tables below select key/value via the
// sentinel indices -1 (line number) and -2 (whole line), so this delimiter is
// never actually consulted.
constexpr char kUnusedLookupDelim = '\t';
}  // namespace
34 | |
35 | // This Op generates a vocab remapping Tensor from an old and new vocabulary |
36 | // file that maps new ID's to old ID's. |
37 | class GenerateVocabRemappingOp : public OpKernel { |
38 | public: |
39 | explicit GenerateVocabRemappingOp(OpKernelConstruction* context) |
40 | : OpKernel(context) { |
41 | OP_REQUIRES_OK(context, |
42 | context->GetAttr("new_vocab_offset" , &new_vocab_offset_)); |
43 | OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab" , &num_new_vocab_)); |
44 | OP_REQUIRES_OK(context, |
45 | context->GetAttr("old_vocab_size" , &old_vocab_size_)); |
46 | } |
47 | |
48 | void Compute(OpKernelContext* context) override { |
49 | const Tensor* new_vocab_file_tensor; |
50 | OP_REQUIRES_OK(context, |
51 | context->input("new_vocab_file" , &new_vocab_file_tensor)); |
52 | OP_REQUIRES(context, |
53 | TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()), |
54 | errors::InvalidArgument( |
55 | "new_vocab_file should be a single string, but got " , |
56 | new_vocab_file_tensor->shape().DebugString())); |
57 | |
58 | // Build a new ID->token lookup table. |
59 | const string& new_vocab_filename = |
60 | new_vocab_file_tensor->scalar<tstring>()(); |
61 | OP_REQUIRES(context, !new_vocab_filename.empty(), |
62 | errors::InvalidArgument("new vocab filename cannot be empty." )); |
63 | lookup::HashTable<int64_t, tstring>* new_vocab_table = |
64 | new lookup::HashTable<int64_t, tstring>(context, this); |
65 | core::ScopedUnref unref_new(new_vocab_table); |
66 | // Note: we pass -1 (unknown) for vocab_size, which is supposed to be the |
67 | // total elements in file. This is different from num_new_vocab_, which |
68 | // accounts for partitioning. |
69 | OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile( |
70 | new_vocab_filename, |
71 | -1, // vocab_size |
72 | kUnusedLookupDelim, |
73 | -1, // key_index, use the line number. |
74 | -2, // value_index, use the whole line/token. |
75 | 0, // No offset. |
76 | context->env(), new_vocab_table)); |
77 | OP_REQUIRES(context, |
78 | new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(), |
79 | errors::InvalidArgument("lookup table size must be larger than " |
80 | "last new vocab entry's line" )); |
81 | |
82 | const Tensor* old_vocab_file_tensor; |
83 | OP_REQUIRES_OK(context, |
84 | context->input("old_vocab_file" , &old_vocab_file_tensor)); |
85 | OP_REQUIRES(context, |
86 | TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()), |
87 | errors::InvalidArgument( |
88 | "old_vocab_file should be a single string, but got " , |
89 | old_vocab_file_tensor->shape().DebugString())); |
90 | // Build a token->old ID lookup table. |
91 | const string& old_vocab_filename = |
92 | old_vocab_file_tensor->scalar<tstring>()(); |
93 | OP_REQUIRES(context, !old_vocab_filename.empty(), |
94 | errors::InvalidArgument("new vocab filename cannot be empty." )); |
95 | lookup::HashTable<tstring, int64_t>* old_vocab_table = |
96 | new lookup::HashTable<tstring, int64_t>(context, this); |
97 | core::ScopedUnref unref_old(old_vocab_table); |
98 | // Note: If old_vocab_size_ is -1 (unknown), we retrieve all elements in |
99 | // file (see TextFileLineIterator). |
100 | OP_REQUIRES_OK(context, |
101 | lookup::InitializeTableFromTextFile( |
102 | old_vocab_filename, old_vocab_size_, kUnusedLookupDelim, |
103 | -2, // key_index, use the whole line/token. |
104 | -1, // value_index, use the line number. |
105 | 0, // No offset. |
106 | context->env(), old_vocab_table)); |
107 | |
108 | // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ..., |
109 | // new_vocab_offset + num_new_vocab_] |
110 | // The double look-up requires a few temporary Tensors. |
111 | Tensor new_ids; |
112 | OP_REQUIRES_OK( |
113 | context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}), |
114 | &new_ids)); |
115 | auto new_ids_vec = new_ids.vec<int64_t>(); |
116 | // Note that we should always be able to find tokens for all new ID's, given |
117 | // that the lookup table is constructed with the vocabulary file itself |
118 | // (see the check on offset and table size post-initialization). |
119 | Tensor default_token; |
120 | OP_REQUIRES_OK( |
121 | context, context->allocate_temp( |
122 | DT_STRING, TensorShape({num_new_vocab_}), &default_token)); |
123 | auto default_token_vec = default_token.vec<tstring>(); |
124 | default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */); |
125 | |
126 | Tensor default_id; |
127 | OP_REQUIRES_OK( |
128 | context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}), |
129 | &default_id)); |
130 | auto default_id_vec = default_id.vec<int64_t>(); |
131 | default_id_vec.setConstant(-1 /* NOT_FOUND_ID */); |
132 | |
133 | for (int i = 0; i < num_new_vocab_; ++i) { |
134 | new_ids_vec(i) = static_cast<int64_t>(i + new_vocab_offset_); |
135 | } |
136 | Tensor tokens; |
137 | OP_REQUIRES_OK(context, |
138 | context->allocate_temp( |
139 | DT_STRING, TensorShape({num_new_vocab_}), &tokens)); |
140 | Tensor* remapping; |
141 | OP_REQUIRES_OK(context, |
142 | context->allocate_output( |
143 | "remapping" , TensorShape({num_new_vocab_}), &remapping)); |
144 | // In the corner case where num_new_vocab_ is 0 (we are dealing with an |
145 | // OOV-only partition), we should not do this lookup. |
146 | if (num_new_vocab_ != 0) { |
147 | OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens, |
148 | default_token)); |
149 | OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping, |
150 | default_id)); |
151 | } |
152 | // Iterate through remapping to calculate num_present. |
153 | const auto remapping_vec = remapping->vec<int64_t>(); |
154 | int num_present = 0; |
155 | for (int i = 0; i < num_new_vocab_; ++i) { |
156 | if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) { |
157 | ++num_present; |
158 | } |
159 | } |
160 | Tensor* num_present_t; |
161 | OP_REQUIRES_OK(context, |
162 | context->allocate_output("num_present" , TensorShape({}), |
163 | &num_present_t)); |
164 | num_present_t->scalar<int>()() = num_present; |
165 | } |
166 | |
167 | private: |
168 | int new_vocab_offset_; |
169 | int num_new_vocab_; |
170 | int old_vocab_size_; |
171 | }; |
172 | |
// CPU-only registration: the op does file I/O and hash-table lookups, so there
// is no device-specific (GPU) implementation.
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
                        GenerateVocabRemappingOp);
175 | |
176 | } // namespace tensorflow |
177 | |