1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <locale> |
#include <string>
#include <vector>
19 | |
20 | #include "absl/strings/ascii.h" |
21 | #include "absl/strings/str_cat.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/op_requires.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace text { |
29 | |
30 | namespace { |
31 | template <typename SPLITS_TYPE> |
32 | class StringNGramsOp : public tensorflow::OpKernel { |
33 | public: |
34 | explicit StringNGramsOp(tensorflow::OpKernelConstruction* context) |
35 | : tensorflow::OpKernel(context) { |
36 | OP_REQUIRES_OK(context, context->GetAttr("separator" , &separator_)); |
37 | OP_REQUIRES_OK(context, context->GetAttr("ngram_widths" , &ngram_widths_)); |
38 | OP_REQUIRES_OK(context, context->GetAttr("left_pad" , &left_pad_)); |
39 | OP_REQUIRES_OK(context, context->GetAttr("right_pad" , &right_pad_)); |
40 | OP_REQUIRES_OK(context, context->GetAttr("pad_width" , &pad_width_)); |
41 | OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences" , |
42 | &preserve_short_)); |
43 | } |
44 | |
45 | int get_pad_width(const int ngram_width) const { |
46 | // Ngrams can be padded with either a fixed pad width or a dynamic pad |
47 | // width depending on the 'pad_width' arg, but in no case should the padding |
48 | // ever be wider than 'ngram_width' - 1. |
49 | return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_, |
50 | ngram_width - 1); |
51 | } |
52 | |
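  // Returns the number of ngrams of width `ngram_width` produced by a
  // sequence of `length` tokens once padding is applied:
  //   max(0, (length + 2 * pad_width) - ngram_width + 1).
  // Fails with InvalidArgument if the padded length would overflow an int32.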
53 | StatusOr<int> get_num_ngrams(const int length, const int ngram_width) const { |
54 | int64 limit = kint32max; |
55 | int pad_width = get_pad_width(ngram_width); |
56 | if (pad_width > limit / 2 - length) { |
      return errors::InvalidArgument(
          "Pad width could lead to integer overflow, got pad_width = ",
          pad_width);
60 | } |
61 | return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1); |
62 | } |
63 | |
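  // Computes the op output in four phases: validate the ngram widths,
  // validate the input splits, compute the output splits row by row, and
  // finally materialize the ngram strings for each batch item.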
64 | void Compute(tensorflow::OpKernelContext* context) override { |
65 | for (int ngram_width : ngram_widths_) { |
66 | OP_REQUIRES( |
67 | context, ngram_width > 0, |
68 | errors::InvalidArgument("ngram_widths must contain positive values" )); |
69 | } |
70 | |
71 | const tensorflow::Tensor* data; |
72 | OP_REQUIRES_OK(context, context->input("data" , &data)); |
73 | const auto& input_data = data->flat<tstring>().data(); |
74 | |
75 | const tensorflow::Tensor* splits; |
76 | OP_REQUIRES_OK(context, context->input("data_splits" , &splits)); |
77 | const auto& splits_vec = splits->flat<SPLITS_TYPE>(); |
78 | |
    // Validate that the splits are valid indices into the data, but only if
    // any splits are specified.
81 | const int input_data_size = data->flat<tstring>().size(); |
82 | const int splits_vec_size = splits_vec.size(); |
83 | if (splits_vec_size > 0) { |
84 | int prev_split = splits_vec(0); |
      OP_REQUIRES(context, prev_split == 0,
                  errors::InvalidArgument("First split value must be 0, got ",
                                          prev_split));
88 | for (int i = 1; i < splits_vec_size; ++i) { |
89 | bool valid_splits = splits_vec(i) >= prev_split; |
90 | valid_splits = valid_splits && (splits_vec(i) <= input_data_size); |
        OP_REQUIRES(context, valid_splits,
                    errors::InvalidArgument(
                        "Invalid split value ", splits_vec(i), ", must be in [",
                        prev_split, ", ", input_data_size, "]"));
95 | prev_split = splits_vec(i); |
96 | } |
      OP_REQUIRES(context, prev_split == input_data_size,
                  errors::InvalidArgument(
                      "Last split value must be data size. Expected ",
                      input_data_size, ", got ", prev_split));
101 | } |
102 | |
103 | int num_batch_items = splits_vec.size() - 1; |
104 | tensorflow::Tensor* ngrams_splits; |
105 | OP_REQUIRES_OK( |
106 | context, context->allocate_output(1, splits->shape(), &ngrams_splits)); |
107 | auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data(); |
108 | |
    // If there is no data or there are no splits, return an empty ragged
    // tensor.
110 | if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) { |
111 | tensorflow::Tensor* empty; |
112 | OP_REQUIRES_OK(context, |
113 | context->allocate_output(0, data->shape(), &empty)); |
114 | for (int i = 0; i <= num_batch_items; ++i) { |
115 | ngrams_splits_data[i] = 0; |
116 | } |
117 | return; |
118 | } |
119 | |
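    // Compute the output splits: row i contributes the sum over all ngram
    // widths of get_num_ngrams(row_length, width), bumped up to one ngram
    // when preserve_short_ is set and the row is non-empty.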
120 | ngrams_splits_data[0] = 0; |
121 | for (int i = 1; i <= num_batch_items; ++i) { |
122 | int length = splits_vec(i) - splits_vec(i - 1); |
123 | int num_ngrams = 0; |
124 | for (int ngram_width : ngram_widths_) { |
125 | auto ngrams_or = get_num_ngrams(length, ngram_width); |
126 | OP_REQUIRES_OK(context, ngrams_or.status()); |
127 | num_ngrams += ngrams_or.value(); |
128 | } |
129 | if (preserve_short_ && length > 0 && num_ngrams == 0) { |
130 | num_ngrams = 1; |
131 | } |
132 | ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams; |
133 | } |
134 | |
135 | tensorflow::Tensor* ngrams; |
136 | OP_REQUIRES_OK( |
137 | context, |
138 | context->allocate_output( |
139 | 0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams)); |
140 | auto ngrams_data = ngrams->flat<tstring>().data(); |
141 | |
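    // Materialize the ngram strings for each batch item, one ngram width at
    // a time, advancing output_start_idx past each width's ngrams.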
142 | for (int i = 0; i < num_batch_items; ++i) { |
143 | auto data_start = &input_data[splits_vec(i)]; |
144 | int output_start_idx = ngrams_splits_data[i]; |
145 | for (int ngram_width : ngram_widths_) { |
146 | auto output_start = &ngrams_data[output_start_idx]; |
147 | int length = splits_vec(i + 1) - splits_vec(i); |
148 | auto ngrams_or = get_num_ngrams(length, ngram_width); |
149 | OP_REQUIRES_OK(context, ngrams_or.status()); |
150 | int num_ngrams = ngrams_or.value(); |
151 | CreateNgrams(data_start, output_start, num_ngrams, ngram_width); |
152 | output_start_idx += num_ngrams; |
153 | } |
      // If we're preserving short sequences, check whether no ngrams were
      // generated by comparing the current output start index to the
      // original one (ngrams_splits_data[i]). If no ngrams were generated,
      // the two are equal, since output_start_idx is incremented by
      // num_ngrams each time a set of ngrams is created.
159 | if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) { |
160 | int data_length = splits_vec(i + 1) - splits_vec(i); |
161 | // One legitimate reason to not have any ngrams when preserve_short_ |
162 | // is true is if the sequence itself is empty. In that case, move on. |
163 | if (data_length == 0) { |
164 | continue; |
165 | } |
166 | // We don't have to worry about dynamic padding sizes here: if padding |
167 | // was dynamic, every sequence would have had sufficient padding to |
168 | // generate at least one ngram. |
169 | |
        // If we reach this point, pad_width_ must be >= 0; a dynamic pad
        // width (pad_width_ == -1) cannot be resolved here, since the
        // implied padding of max(ngram_widths) - 1 depends on an ngram
        // width that is not known.
        OP_REQUIRES(
            context, pad_width_ >= 0,
            errors::InvalidArgument("Pad width should be >= 0 when "
                                    "preserve_short_sequences is True and "
                                    "ngram_widths are not provided, got ",
                                    pad_width_));
179 | int ngram_width = data_length + 2 * pad_width_; |
180 | auto output_start = &ngrams_data[output_start_idx]; |
181 | int num_ngrams = 1; |
182 | CreateNgrams(data_start, output_start, num_ngrams, ngram_width); |
183 | } |
184 | } |
185 | } |
186 | |
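  // Writes `num_ngrams` ngrams of width `ngram_width` into `output`, reading
  // tokens from `data`. Each ngram joins up to `ngram_width` tokens with
  // `separator_`, substituting `left_pad_` / `right_pad_` for positions that
  // fall outside the sequence.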
187 | void CreateNgrams(const tstring* data, tstring* output, int num_ngrams, |
188 | int ngram_width) const { |
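    // Worked example (illustrative): with tokens {"a", "b", "c"},
    // ngram_width = 2, pad_width = 1, left_pad = "LP", right_pad = "RP",
    // and separator = "|", num_ngrams = 4 and the generated ngrams are
    // "LP|a", "a|b", "b|c", "c|RP".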
189 | for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) { |
190 | int pad_width = get_pad_width(ngram_width); |
191 | int left_padding = std::max(0, pad_width - ngram_index); |
192 | int right_padding = |
193 | std::max(0, pad_width - (num_ngrams - (ngram_index + 1))); |
194 | int num_tokens = ngram_width - (left_padding + right_padding); |
195 | int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width; |
196 | |
197 | // Calculate the total expected size of the ngram so we can reserve the |
198 | // correct amount of space in the string. |
199 | int ngram_size = 0; |
200 | // Size of the left padding. |
201 | ngram_size += left_padding * left_pad_.length(); |
202 | // Size of the tokens. |
203 | for (int n = 0; n < num_tokens; ++n) { |
204 | ngram_size += data[data_start_index + n].length(); |
205 | } |
206 | // Size of the right padding. |
207 | ngram_size += right_padding * right_pad_.length(); |
208 | // Size of the separators. |
209 | int num_separators = left_padding + right_padding + num_tokens - 1; |
210 | ngram_size += num_separators * separator_.length(); |
211 | |
212 | // Build the ngram. |
213 | tstring* ngram = &output[ngram_index]; |
214 | ngram->reserve(ngram_size); |
215 | for (int n = 0; n < left_padding; ++n) { |
216 | ngram->append(left_pad_); |
217 | ngram->append(separator_); |
218 | } |
      // Only output the first num_tokens - 1 (token, separator) pairs.
220 | for (int n = 0; n < num_tokens - 1; ++n) { |
221 | ngram->append(data[data_start_index + n]); |
222 | ngram->append(separator_); |
223 | } |
      // Handle the cases where there are no tokens or no right padding, as
      // these can result in consecutive separators.
226 | if (num_tokens > 0) { |
        // If we have tokens, output the last one, then pair each separator
        // with the right padding that follows, so the ngram ends with either
        // a token or a right pad.
230 | ngram->append(data[data_start_index + num_tokens - 1]); |
231 | for (int n = 0; n < right_padding; ++n) { |
232 | ngram->append(separator_); |
233 | ngram->append(right_pad_); |
234 | } |
235 | } else { |
        // If we don't have tokens, the last item appended to the ngram was
        // the separator from the left padding loop above. Hence, alternate
        // right pad and separator, making sure to finish with a pad rather
        // than a separator.
240 | for (int n = 0; n < right_padding - 1; ++n) { |
241 | ngram->append(right_pad_); |
242 | ngram->append(separator_); |
243 | } |
244 | ngram->append(right_pad_); |
245 | } |
246 | |
247 | // In debug mode only: validate that we've reserved enough space for the |
248 | // ngram. |
249 | DCHECK_EQ(ngram_size, ngram->size()); |
250 | } |
251 | } |
252 | |
253 | string separator_; |
254 | string left_pad_; |
255 | string right_pad_; |
256 | bool use_pad_; |
257 | bool extend_pad_; |
258 | bool preserve_short_; |
259 | |
260 | std::vector<int> ngram_widths_; |
261 | int pad_width_; |
262 | }; |
263 | |
264 | } // namespace |
265 | REGISTER_KERNEL_BUILDER(Name("StringNGrams" ) |
266 | .Device(tensorflow::DEVICE_CPU) |
267 | .TypeConstraint<int32>("Tsplits" ), |
268 | StringNGramsOp<int32>); |
269 | REGISTER_KERNEL_BUILDER(Name("StringNGrams" ) |
270 | .Device(tensorflow::DEVICE_CPU) |
271 | .TypeConstraint<int64_t>("Tsplits" ), |
272 | StringNGramsOp<int64_t>); |
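
// This kernel backs tf.strings.ngrams. A minimal usage sketch from Python
// (illustrative; no padding is applied unless pad_values is set):
//
//   tf.strings.ngrams(
//       tf.ragged.constant([["a", "b", "c"]]), ngram_width=2, separator="|")
//   # => [["a|b", "b|c"]]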
273 | |
274 | } // namespace text |
275 | } // namespace tensorflow |
276 | |