1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Op that looks up items from a sparse tensor in an embedding matrix. |
17 | // The sparse lookup tensor is represented by three individual tensors: lookup, |
18 | // indices, and dense_shape. The representation assumes that the corresponding |
19 | // dense tensor would satisfy: |
20 | // * dense.shape = dense_shape |
21 | // * dense[tuple(indices[i])] = lookup[i] |
22 | // |
23 | // By convention, indices should be sorted. |
24 | // |
25 | // Options: |
26 | // combiner: The reduction op (SUM, MEAN, SQRTN). |
27 | // * SUM computes the weighted sum of the embedding results. |
28 | // * MEAN is the weighted sum divided by the total weight. |
29 | // * SQRTN is the weighted sum divided by the square root of the sum of the |
30 | // squares of the weights. |
31 | // |
32 | // Input: |
33 | // Tensor[0]: Ids to look up, dim.size == 1, int32. |
34 | // Tensor[1]: Indices, int32. |
35 | // Tensor[2]: Dense shape, int32. |
36 | // Tensor[3]: Weights to use for aggregation, float. |
37 | // Tensor[4]: Params, a matrix of multi-dimensional items, |
38 | // dim.size >= 2, float. |
39 | // |
40 | // Output: |
41 | // A (dense) tensor representing the combined embeddings for the sparse ids. |
42 | // For each row in the sparse tensor represented by (lookup, indices, shape) |
43 | // the op looks up the embeddings for all ids in that row, multiplies them by |
44 | // the corresponding weight, and combines these embeddings as specified in the |
45 | // last dimension. |
46 | // |
47 | // Output.dim = [l0, ... , ln-1, e1, ..., em] |
48 | // Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] |
49 | // |
50 | // For instance, if params is a 10x20 matrix and ids, weights are: |
51 | // |
52 | // [0, 0]: id 1, weight 2.0 |
53 | // [0, 1]: id 3, weight 0.5 |
54 | // [1, 0]: id 0, weight 1.0 |
55 | // [2, 3]: id 1, weight 3.0 |
56 | // |
57 | // with combiner=MEAN, then the output will be a (3, 20) tensor where: |
58 | // |
59 | // output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) |
60 | // output[1, :] = (params[0, :] * 1.0) / 1.0 |
61 | // output[2, :] = (params[1, :] * 3.0) / 3.0 |
62 | // |
63 | // When indices are out of bounds, the op will not succeed. |
64 | |
65 | #include <stdint.h> |
66 | |
67 | #include <algorithm> |
68 | #include <cmath> |
69 | |
70 | #include "tensorflow/lite/c/builtin_op_data.h" |
71 | #include "tensorflow/lite/c/common.h" |
72 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
73 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
74 | #include "tensorflow/lite/kernels/kernel_util.h" |
75 | #include "tensorflow/lite/util.h" |
76 | |
77 | namespace tflite { |
78 | namespace ops { |
79 | namespace builtin { |
80 | |
81 | namespace { |
82 | |
83 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
84 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); |
85 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
86 | |
87 | const TfLiteTensor* ids; |
88 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids)); |
89 | TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); |
90 | TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); |
91 | |
92 | const TfLiteTensor* indices; |
93 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices)); |
94 | TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); |
95 | TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); |
96 | |
97 | const TfLiteTensor* shape; |
98 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape)); |
99 | TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); |
100 | TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32); |
101 | |
102 | const TfLiteTensor* weights; |
103 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights)); |
104 | TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); |
105 | TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); |
106 | |
107 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), |
108 | SizeOfDimension(ids, 0)); |
109 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), |
110 | SizeOfDimension(weights, 0)); |
111 | |
112 | const TfLiteTensor* value; |
113 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value)); |
114 | TF_LITE_ENSURE(context, NumDimensions(value) >= 2); |
115 | |
116 | // Mark the output as a dynamic tensor. |
117 | TfLiteTensor* output; |
118 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
119 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); |
120 | output->allocation_type = kTfLiteDynamic; |
121 | |
122 | return kTfLiteOk; |
123 | } |
124 | |
125 | void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements, |
126 | float current_total_weight, |
127 | float current_squares_weight, int embedding_size, |
128 | float* output) { |
129 | if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) { |
130 | float multiplier = 1.0; |
131 | switch (combiner) { |
132 | case kTfLiteCombinerTypeMean: |
133 | multiplier = current_total_weight; |
134 | break; |
135 | case kTfLiteCombinerTypeSqrtn: |
136 | multiplier = std::sqrt(current_squares_weight); |
137 | break; |
138 | default: |
139 | break; |
140 | } |
141 | for (int k = 0; k < embedding_size; k++) { |
142 | output[k] /= multiplier; |
143 | } |
144 | } |
145 | } |
146 | |
147 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
148 | auto* params = |
149 | reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data); |
150 | TfLiteTensor* output; |
151 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
152 | const TfLiteTensor* ids; |
153 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids)); |
154 | const TfLiteTensor* indices; |
155 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices)); |
156 | const TfLiteTensor* dense_shape; |
157 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape)); |
158 | const TfLiteTensor* weights; |
159 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights)); |
160 | const TfLiteTensor* value; |
161 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value)); |
162 | const size_t values_size = NumElements(value); |
163 | |
164 | const int lookup_rank = SizeOfDimension(indices, 1); |
165 | const int embedding_rank = NumDimensions(value); |
166 | const int num_lookups = SizeOfDimension(ids, 0); |
167 | const int num_rows = SizeOfDimension(value, 0); |
168 | |
169 | // The last dimension gets replaced by the embedding. |
170 | const int output_rank = (lookup_rank - 1) + (embedding_rank - 1); |
171 | |
172 | // Make sure that the actual dense shape of the sparse tensor represented by |
173 | // (loopkup, indices, dense_shape) is consistent. |
174 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank); |
175 | |
176 | // Resize output tensor. |
177 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); |
178 | TF_LITE_ENSURE(context, output_shape != nullptr); |
179 | int k = 0; |
180 | size_t embedding_size = 1; |
181 | size_t lookup_size = 1; |
182 | for (int i = 0; i < lookup_rank - 1; i++, k++) { |
183 | const size_t dim = dense_shape->data.i32[i]; |
184 | TF_LITE_ENSURE_MSG( |
185 | context, |
186 | MultiplyAndCheckOverflow(lookup_size, dim, &lookup_size) == kTfLiteOk, |
187 | "Lookup size overflowed." ); |
188 | output_shape->data[k] = dim; |
189 | } |
190 | for (int i = 1; i < embedding_rank; i++, k++) { |
191 | const size_t dim = SizeOfDimension(value, i); |
192 | TF_LITE_ENSURE_MSG(context, |
193 | MultiplyAndCheckOverflow(embedding_size, dim, |
194 | &embedding_size) == kTfLiteOk, |
195 | "Embedding size overflowed." ); |
196 | output_shape->data[k] = dim; |
197 | } |
198 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); |
199 | const size_t output_size = lookup_size * embedding_size; |
200 | TfLiteTensorRealloc(output_size * sizeof(float), output); |
201 | |
202 | float* output_ptr = GetTensorData<float>(output); |
203 | const float* weights_ptr = GetTensorData<float>(weights); |
204 | const float* value_ptr = GetTensorData<float>(value); |
205 | // Makes sure reallocation was successful. |
206 | TF_LITE_ENSURE(context, output_ptr != nullptr); |
207 | |
208 | std::fill_n(output_ptr, output_size, 0.0f); |
209 | |
210 | // Keep track of the current bucket for aggregation/combination. |
211 | int current_output_offset = 0; |
212 | float current_total_weight = 0.0; |
213 | float current_squares_weight = 0.0; |
214 | int num_elements = 0; |
215 | |
216 | for (int i = 0; i < num_lookups; i++) { |
217 | int idx = ids->data.i32[i]; |
218 | if (idx >= num_rows || idx < 0) { |
219 | TF_LITE_KERNEL_LOG(context, |
220 | "Embedding Lookup Sparse: index out of bounds. " |
221 | "Got %d, and bounds are [0, %d]" , |
222 | idx, num_rows - 1); |
223 | return kTfLiteError; |
224 | } |
225 | |
226 | // Check where we need to aggregate. |
227 | const int example_indices_offset = i * lookup_rank; |
228 | int output_bucket = 0; |
229 | int stride = 1; |
230 | for (int k = (lookup_rank - 1) - 1; k >= 0; k--) { |
231 | output_bucket += indices->data.i32[example_indices_offset + k] * stride; |
232 | stride *= dense_shape->data.i32[k]; |
233 | } |
234 | const int output_offset = output_bucket * embedding_size; |
235 | |
236 | // If we are in a new aggregation bucket and the combiner is not the sum, |
237 | // go back and finalize the result of the previous bucket. |
238 | if (output_offset != current_output_offset) { |
239 | FinalizeAggregation(params->combiner, num_elements, current_total_weight, |
240 | current_squares_weight, embedding_size, |
241 | &output_ptr[current_output_offset]); |
242 | |
243 | // Track next bucket. |
244 | num_elements = 0; |
245 | current_total_weight = 0.0; |
246 | current_squares_weight = 0.0; |
247 | current_output_offset = output_offset; |
248 | } |
249 | |
250 | // Add element to aggregation. |
251 | ++num_elements; |
252 | const int example_embedding_offset = idx * embedding_size; |
253 | const float w = weights_ptr[i]; |
254 | current_squares_weight += w * w; |
255 | current_total_weight += w; |
256 | for (int k = 0; k < embedding_size; k++) { |
257 | // only index if indices are valid |
258 | if (current_output_offset + k < 0) continue; |
259 | if (current_output_offset + k >= output_size) continue; |
260 | if (example_embedding_offset + k < 0) continue; |
261 | if (example_embedding_offset + k >= values_size) continue; |
262 | output_ptr[current_output_offset + k] += |
263 | value_ptr[example_embedding_offset + k] * w; |
264 | } |
265 | } |
266 | |
267 | // Finalize last bucket. |
268 | FinalizeAggregation(params->combiner, num_elements, current_total_weight, |
269 | current_squares_weight, embedding_size, |
270 | &GetTensorData<float>(output)[current_output_offset]); |
271 | |
272 | return kTfLiteOk; |
273 | } |
274 | |
275 | } // namespace |
276 | |
277 | TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() { |
278 | static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; |
279 | return &r; |
280 | } |
281 | |
282 | } // namespace builtin |
283 | } // namespace ops |
284 | } // namespace tflite |
285 | |