1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Generate a list of skip grams from an input.
17//
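// Options:
//   ngram_size: number of words in each output item.
//   max_skip_size: maximum number of words that may be skipped between two
//     consecutive words of an item. When it is 0, the op generates ordinary
//     (contiguous) n-grams.
//   include_all_ngrams: if true, also emit items with 1 to ngram_size words
//     instead of only items with exactly ngram_size words.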
23//
// Input:
//   A string tensor whose first string is the sentence to generate n-grams
//   from. Words are separated by whitespace.
//   Dim = {1}
27//
// Output:
//   A list of strings, each containing ngram_size words (or up to ngram_size
//   words when include_all_ngrams is set) joined by a single space.
//   Dim = {num_ngram}
31
32#include <ctype.h>
33
34#include <vector>
35
36#include "tensorflow/lite/c/builtin_op_data.h"
37#include "tensorflow/lite/c/common.h"
38#include "tensorflow/lite/kernels/internal/compatibility.h"
39#include "tensorflow/lite/kernels/kernel_util.h"
40#include "tensorflow/lite/string_util.h"
41
42namespace tflite {
43namespace ops {
44namespace builtin {
45
46namespace {
47
48TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
49 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
50 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
51
52 const TfLiteTensor* input_tensor;
53 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor));
54 TF_LITE_ENSURE_TYPES_EQ(context, input_tensor->type, kTfLiteString);
55 TfLiteTensor* output_tensor;
56 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor));
57 TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteString);
58 return kTfLiteOk;
59}
60
61bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) {
62 if (size <= 0) {
63 return false;
64 }
65 if (params->include_all_ngrams) {
66 return size <= params->ngram_size;
67 } else {
68 return size == params->ngram_size;
69 }
70}
71
72bool ShouldStepInRecursion(const TfLiteSkipGramParams* params,
73 const std::vector<int>& stack, int stack_idx,
74 int num_words) {
75 // If current stack size and next word enumeration are within valid range.
76 if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) {
77 // If this stack is empty, step in for first word enumeration.
78 if (stack_idx == 0) {
79 return true;
80 }
81 // If next word enumeration are within the range of max_skip_size.
82 // NOTE: equivalent to
83 // next_word_idx = stack[stack_idx] + 1
84 // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1
85 if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) {
86 return true;
87 }
88 }
89 return false;
90}
91
92TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
93 auto* params = reinterpret_cast<TfLiteSkipGramParams*>(node->builtin_data);
94
95 // Split sentence to words.
96 std::vector<StringRef> words;
97 const TfLiteTensor* input;
98 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
99 tflite::StringRef strref = tflite::GetString(input, 0);
100 int prev_idx = 0;
101 for (int i = 1; i < strref.len; i++) {
102 if (isspace(*(strref.str + i))) {
103 if (i > prev_idx && !isspace(*(strref.str + prev_idx))) {
104 words.push_back({strref.str + prev_idx, i - prev_idx});
105 }
106 prev_idx = i + 1;
107 }
108 }
109 if (strref.len > prev_idx) {
110 words.push_back({strref.str + prev_idx, strref.len - prev_idx});
111 }
112
113 // Generate n-grams recursively.
114 tflite::DynamicBuffer buf;
115 if (words.size() < params->ngram_size) {
116 buf.WriteToTensorAsVector(GetOutput(context, node, 0));
117 return kTfLiteOk;
118 }
119
120 // Stack stores the index of word used to generate ngram.
121 // The size of stack is the size of ngram.
122 std::vector<int> stack(params->ngram_size, 0);
123 // Stack index that indicates which depth the recursion is operating at.
124 int stack_idx = 1;
125 int num_words = words.size();
126
127 while (stack_idx >= 0) {
128 if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) {
129 // When current depth can fill with a new word
130 // and the new word is within the max range to skip,
131 // fill this word to stack, recurse into next depth.
132 stack[stack_idx]++;
133 stack_idx++;
134 if (stack_idx < params->ngram_size) {
135 stack[stack_idx] = stack[stack_idx - 1];
136 }
137 } else {
138 if (ShouldIncludeCurrentNgram(params, stack_idx)) {
139 // Add n-gram to tensor buffer when the stack has filled with enough
140 // words to generate the ngram.
141 std::vector<StringRef> gram(stack_idx);
142 for (int i = 0; i < stack_idx; i++) {
143 gram[i] = words[stack[i]];
144 }
145 buf.AddJoinedString(gram, ' ');
146 }
147 // When current depth cannot fill with a valid new word,
148 // and not in last depth to generate ngram,
149 // step back to previous depth to iterate to next possible word.
150 stack_idx--;
151 }
152 }
153
154 buf.WriteToTensorAsVector(GetOutput(context, node, 0));
155 return kTfLiteOk;
156}
157} // namespace
158
159TfLiteRegistration* Register_SKIP_GRAM() {
160 static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
161 return &r;
162}
163
164} // namespace builtin
165} // namespace ops
166} // namespace tflite
167