/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <locale>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace text {

namespace {
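// Kernel for the StringNGrams op. As a rough sketch of its semantics (an
// illustrative example, not an exhaustive spec): with data = ["a", "b", "c"],
// data_splits = [0, 3], ngram_widths = [2], separator = " ", and no padding,
// the outputs are ngrams = ["a b", "b c"] and ngrams_splits = [0, 2].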
template <typename SPLITS_TYPE>
class StringNGramsOp : public tensorflow::OpKernel {
 public:
  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
                                             &preserve_short_));
  }

  int get_pad_width(const int ngram_width) const {
    // Ngrams can be padded with either a fixed pad width or a dynamic pad
    // width depending on the 'pad_width' arg, but in no case should the
    // padding ever be wider than 'ngram_width' - 1.
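    // For example (an illustrative sketch): with ngram_width = 3, a
    // pad_width_ of -1 yields a pad width of 2 (i.e. ngram_width - 1), a
    // pad_width_ of 1 yields 1, and a pad_width_ of 5 is clamped down to 2.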
    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
                    ngram_width - 1);
  }

  StatusOr<int> get_num_ngrams(const int length, const int ngram_width) const {
    int64 limit = kint32max;
    int pad_width = get_pad_width(ngram_width);
    if (pad_width > limit / 2 - length) {
      return errors::InvalidArgument(
          "Pad width could lead to integer overflow, got pad_width = ",
          pad_width);
    }
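    // Worked example (illustrative): length = 4, ngram_width = 3, and an
    // effective pad_width of 1 give max(0, (4 + 2 * 1) - 3 + 1) = 4 ngrams;
    // length = 1, ngram_width = 3, and pad_width = 0 give max(0, -1) = 0,
    // the case that preserve_short_sequences exists to handle.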
    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    for (int ngram_width : ngram_widths_) {
      OP_REQUIRES(
          context, ngram_width > 0,
          errors::InvalidArgument("ngram_widths must contain positive values"));
    }

    const tensorflow::Tensor* data;
    OP_REQUIRES_OK(context, context->input("data", &data));
    const auto& input_data = data->flat<tstring>().data();

    const tensorflow::Tensor* splits;
    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
    const auto& splits_vec = splits->flat<SPLITS_TYPE>();

    // Validate that the splits are valid indices into data, but only if any
    // splits are specified.
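    // For instance (illustrative values): with input_data_size = 5,
    // splits_vec = [0, 2, 5] is valid (two batch items of lengths 2 and 3),
    // whereas [0, 3, 2] (non-increasing) or [0, 2, 4] (last value != data
    // size) must be rejected.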
    const int input_data_size = data->flat<tstring>().size();
    const int splits_vec_size = splits_vec.size();
    if (splits_vec_size > 0) {
      int prev_split = splits_vec(0);
      OP_REQUIRES(context, prev_split == 0,
                  errors::InvalidArgument("First split value must be 0, got ",
                                          prev_split));
      for (int i = 1; i < splits_vec_size; ++i) {
        bool valid_splits = splits_vec(i) >= prev_split;
        valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
        OP_REQUIRES(context, valid_splits,
                    errors::InvalidArgument(
                        "Invalid split value ", splits_vec(i), ", must be in [",
                        prev_split, ", ", input_data_size, "]"));
        prev_split = splits_vec(i);
      }
      OP_REQUIRES(context, prev_split == input_data_size,
                  errors::InvalidArgument(
                      "Last split value must be data size. Expected ",
                      input_data_size, ", got ", prev_split));
    }

    int num_batch_items = splits_vec.size() - 1;
    tensorflow::Tensor* ngrams_splits;
    OP_REQUIRES_OK(
        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();

    // If there is no data or there are no splits, return an empty ragged
    // tensor.
    if (input_data_size == 0 || splits_vec.size() == 0) {
      tensorflow::Tensor* empty;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, data->shape(), &empty));
      for (int i = 0; i <= num_batch_items; ++i) {
        ngrams_splits_data[i] = 0;
      }
      return;
    }

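    // Fill in the output splits from the per-item ngram counts. Example
    // (illustrative): with ngram_widths_ = {2, 3}, pad_width_ = 0, and batch
    // item lengths {3, 1}, item 0 yields (3 - 2 + 1) + (3 - 3 + 1) = 3 ngrams
    // and item 1 yields 0 (or 1 with preserve_short_ set), so ngrams_splits
    // becomes [0, 3, 3] (or [0, 3, 4]).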
    ngrams_splits_data[0] = 0;
    for (int i = 1; i <= num_batch_items; ++i) {
      int length = splits_vec(i) - splits_vec(i - 1);
      int num_ngrams = 0;
      for (int ngram_width : ngram_widths_) {
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        num_ngrams += ngrams_or.value();
      }
      if (preserve_short_ && length > 0 && num_ngrams == 0) {
        num_ngrams = 1;
      }
      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
    }

    tensorflow::Tensor* ngrams;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
    auto ngrams_data = ngrams->flat<tstring>().data();

    for (int i = 0; i < num_batch_items; ++i) {
      auto data_start = &input_data[splits_vec(i)];
      int output_start_idx = ngrams_splits_data[i];
      for (int ngram_width : ngram_widths_) {
        auto output_start = &ngrams_data[output_start_idx];
        int length = splits_vec(i + 1) - splits_vec(i);
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        int num_ngrams = ngrams_or.value();
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
        output_start_idx += num_ngrams;
      }
      // If we're preserving short sequences, check whether any ngrams were
      // generated by comparing the current output_start_idx to the original
      // one (ngrams_splits_data[i]). If no ngrams were generated, the two are
      // equal, since we increment output_start_idx by num_ngrams every time
      // we create a set of ngrams.
      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
        int data_length = splits_vec(i + 1) - splits_vec(i);
        // One legitimate reason to not have any ngrams when preserve_short_
        // is true is if the sequence itself is empty. In that case, move on.
        if (data_length == 0) {
          continue;
        }
        // We don't have to worry about dynamic padding sizes here: if padding
        // were dynamic, every sequence would have had sufficient padding to
        // generate at least one ngram.

        // If we reach this point with pad_width_ == -1 (dynamic padding),
        // then ngram_widths must have been empty, so there is no ngram width
        // from which to derive a padding of ngram_width - 1. Require an
        // explicit, non-negative pad_width instead.
        OP_REQUIRES(
            context, pad_width_ >= 0,
            errors::InvalidArgument("Pad width should be >= 0 when "
                                    "preserve_short_sequences is True and "
                                    "ngram_widths are not provided, got ",
                                    pad_width_));
        int ngram_width = data_length + 2 * pad_width_;
        auto output_start = &ngrams_data[output_start_idx];
        int num_ngrams = 1;
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
      }
    }
  }

  void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
                    int ngram_width) const {
    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
      int pad_width = get_pad_width(ngram_width);
      int left_padding = std::max(0, pad_width - ngram_index);
      int right_padding =
          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
      int num_tokens = ngram_width - (left_padding + right_padding);
      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;
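      // Worked example (illustrative): with data = ["a"], ngram_width = 3,
      // and pad_width = 2 (so num_ngrams = 3), ngram_index 0 gives
      // left_padding 2 / right_padding 0, index 1 gives 1 / 1, and index 2
      // gives 0 / 2, with num_tokens = 1 throughout; with pads "LP"/"RP" and
      // separator " " the ngrams are "LP LP a", "LP a RP", and "a RP RP".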

      // Calculate the total expected size of the ngram so we can reserve the
      // correct amount of space in the string.
      int ngram_size = 0;
      // Size of the left padding.
      ngram_size += left_padding * left_pad_.length();
      // Size of the tokens.
      for (int n = 0; n < num_tokens; ++n) {
        ngram_size += data[data_start_index + n].length();
      }
      // Size of the right padding.
      ngram_size += right_padding * right_pad_.length();
      // Size of the separators.
      int num_separators = left_padding + right_padding + num_tokens - 1;
      ngram_size += num_separators * separator_.length();

      // Build the ngram.
      tstring* ngram = &output[ngram_index];
      ngram->reserve(ngram_size);
      for (int n = 0; n < left_padding; ++n) {
        ngram->append(left_pad_);
        ngram->append(separator_);
      }
      // Only output the first num_tokens - 1 pairs of token and separator.
      for (int n = 0; n < num_tokens - 1; ++n) {
        ngram->append(data[data_start_index + n]);
        ngram->append(separator_);
      }
      // Handle the cases where there are no tokens or no right padding, as
      // these can result in consecutive separators.
      if (num_tokens > 0) {
        // If we have tokens, output the last one, and then pair each
        // separator with the right padding that follows, so that the ngram
        // ends either with a token or with the right pad.
        ngram->append(data[data_start_index + num_tokens - 1]);
        for (int n = 0; n < right_padding; ++n) {
          ngram->append(separator_);
          ngram->append(right_pad_);
        }
      } else {
        // If we don't have tokens, then the last item inserted into the ngram
        // was the separator from the left padding loop above. Hence, output
        // right pad and separator pairs, making sure to finish with a pad,
        // not a separator.
        for (int n = 0; n < right_padding - 1; ++n) {
          ngram->append(right_pad_);
          ngram->append(separator_);
        }
        ngram->append(right_pad_);
      }
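      // Edge-case sketch (illustrative): for an empty sequence with
      // ngram_width = 2 and pad_width = 1, num_ngrams = (0 + 2 * 1) - 2 + 1
      // = 1 and num_tokens = 0, so with pads "LP"/"RP" and separator " " the
      // single ngram is "LP RP": the left loop emits "LP ", and the else
      // branch above appends the bare "RP".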

      // In debug mode only: validate that we've reserved enough space for the
      // ngram.
      DCHECK_EQ(ngram_size, ngram->size());
    }
  }

  string separator_;
  string left_pad_;
  string right_pad_;
  bool preserve_short_;

  std::vector<int> ngram_widths_;
  int pad_width_;
};

}  // namespace

REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        StringNGramsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int64_t>("Tsplits"),
                        StringNGramsOp<int64_t>);

}  // namespace text
}  // namespace tensorflow