1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Generate a list of skip grams from an input. |
17 | // |
18 | // Options: |
19 | // ngram_size: num of words for each output item. |
// max_skip_size: max num of words that may be skipped between picked words.
//     When it is 0, the op generates plain (contiguous) n-grams.
22 | // include_all_ngrams: include all ngrams with size up to ngram_size. |
23 | // |
24 | // Input: |
25 | // A string tensor to generate n-grams. |
26 | // Dim = {1} |
27 | // |
28 | // Output: |
29 | // A list of strings, each of which contains ngram_size words. |
30 | // Dim = {num_ngram} |
31 | |
32 | #include <ctype.h> |
33 | |
34 | #include <vector> |
35 | |
36 | #include "tensorflow/lite/c/builtin_op_data.h" |
37 | #include "tensorflow/lite/c/common.h" |
38 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
39 | #include "tensorflow/lite/kernels/kernel_util.h" |
40 | #include "tensorflow/lite/string_util.h" |
41 | |
42 | namespace tflite { |
43 | namespace ops { |
44 | namespace builtin { |
45 | |
46 | namespace { |
47 | |
48 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
49 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
50 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
51 | |
52 | const TfLiteTensor* input_tensor; |
53 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input_tensor)); |
54 | TF_LITE_ENSURE_TYPES_EQ(context, input_tensor->type, kTfLiteString); |
55 | TfLiteTensor* output_tensor; |
56 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output_tensor)); |
57 | TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteString); |
58 | return kTfLiteOk; |
59 | } |
60 | |
61 | bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) { |
62 | if (size <= 0) { |
63 | return false; |
64 | } |
65 | if (params->include_all_ngrams) { |
66 | return size <= params->ngram_size; |
67 | } else { |
68 | return size == params->ngram_size; |
69 | } |
70 | } |
71 | |
72 | bool ShouldStepInRecursion(const TfLiteSkipGramParams* params, |
73 | const std::vector<int>& stack, int stack_idx, |
74 | int num_words) { |
75 | // If current stack size and next word enumeration are within valid range. |
76 | if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) { |
77 | // If this stack is empty, step in for first word enumeration. |
78 | if (stack_idx == 0) { |
79 | return true; |
80 | } |
81 | // If next word enumeration are within the range of max_skip_size. |
82 | // NOTE: equivalent to |
83 | // next_word_idx = stack[stack_idx] + 1 |
84 | // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1 |
85 | if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) { |
86 | return true; |
87 | } |
88 | } |
89 | return false; |
90 | } |
91 | |
92 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
93 | auto* params = reinterpret_cast<TfLiteSkipGramParams*>(node->builtin_data); |
94 | |
95 | // Split sentence to words. |
96 | std::vector<StringRef> words; |
97 | const TfLiteTensor* input; |
98 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
99 | tflite::StringRef strref = tflite::GetString(input, 0); |
100 | int prev_idx = 0; |
101 | for (int i = 1; i < strref.len; i++) { |
102 | if (isspace(*(strref.str + i))) { |
103 | if (i > prev_idx && !isspace(*(strref.str + prev_idx))) { |
104 | words.push_back({strref.str + prev_idx, i - prev_idx}); |
105 | } |
106 | prev_idx = i + 1; |
107 | } |
108 | } |
109 | if (strref.len > prev_idx) { |
110 | words.push_back({strref.str + prev_idx, strref.len - prev_idx}); |
111 | } |
112 | |
113 | // Generate n-grams recursively. |
114 | tflite::DynamicBuffer buf; |
115 | if (words.size() < params->ngram_size) { |
116 | buf.WriteToTensorAsVector(GetOutput(context, node, 0)); |
117 | return kTfLiteOk; |
118 | } |
119 | |
120 | // Stack stores the index of word used to generate ngram. |
121 | // The size of stack is the size of ngram. |
122 | std::vector<int> stack(params->ngram_size, 0); |
123 | // Stack index that indicates which depth the recursion is operating at. |
124 | int stack_idx = 1; |
125 | int num_words = words.size(); |
126 | |
127 | while (stack_idx >= 0) { |
128 | if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) { |
129 | // When current depth can fill with a new word |
130 | // and the new word is within the max range to skip, |
131 | // fill this word to stack, recurse into next depth. |
132 | stack[stack_idx]++; |
133 | stack_idx++; |
134 | if (stack_idx < params->ngram_size) { |
135 | stack[stack_idx] = stack[stack_idx - 1]; |
136 | } |
137 | } else { |
138 | if (ShouldIncludeCurrentNgram(params, stack_idx)) { |
139 | // Add n-gram to tensor buffer when the stack has filled with enough |
140 | // words to generate the ngram. |
141 | std::vector<StringRef> gram(stack_idx); |
142 | for (int i = 0; i < stack_idx; i++) { |
143 | gram[i] = words[stack[i]]; |
144 | } |
145 | buf.AddJoinedString(gram, ' '); |
146 | } |
147 | // When current depth cannot fill with a valid new word, |
148 | // and not in last depth to generate ngram, |
149 | // step back to previous depth to iterate to next possible word. |
150 | stack_idx--; |
151 | } |
152 | } |
153 | |
154 | buf.WriteToTensorAsVector(GetOutput(context, node, 0)); |
155 | return kTfLiteOk; |
156 | } |
157 | } // namespace |
158 | |
159 | TfLiteRegistration* Register_SKIP_GRAM() { |
160 | static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; |
161 | return &r; |
162 | } |
163 | |
164 | } // namespace builtin |
165 | } // namespace ops |
166 | } // namespace tflite |
167 | |