1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Op that looks up items from a sparse tensor in an embedding matrix.
17// The sparse lookup tensor is represented by three individual tensors: lookup,
// indices, and dense_shape. The representation assumes that the corresponding
19// dense tensor would satisfy:
20// * dense.shape = dense_shape
21// * dense[tuple(indices[i])] = lookup[i]
22//
23// By convention, indices should be sorted.
24//
25// Options:
26// combiner: The reduction op (SUM, MEAN, SQRTN).
27// * SUM computes the weighted sum of the embedding results.
28// * MEAN is the weighted sum divided by the total weight.
29// * SQRTN is the weighted sum divided by the square root of the sum of the
30// squares of the weights.
31//
32// Input:
// Tensor[0]: Ids to look up, dim.size == 1, int32.
34// Tensor[1]: Indices, int32.
35// Tensor[2]: Dense shape, int32.
36// Tensor[3]: Weights to use for aggregation, float.
37// Tensor[4]: Params, a matrix of multi-dimensional items,
38// dim.size >= 2, float.
39//
40// Output:
41// A (dense) tensor representing the combined embeddings for the sparse ids.
// For each row in the sparse tensor represented by (lookup, indices, shape)
// the op looks up the embeddings for all ids in that row, multiplies them by
// the corresponding weights, and combines them using the `combiner` option,
// collapsing the last dimension of the sparse tensor.
46//
47// Output.dim = [l0, ... , ln-1, e1, ..., em]
// where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em].
49//
50// For instance, if params is a 10x20 matrix and ids, weights are:
51//
52// [0, 0]: id 1, weight 2.0
53// [0, 1]: id 3, weight 0.5
54// [1, 0]: id 0, weight 1.0
55// [2, 3]: id 1, weight 3.0
56//
57// with combiner=MEAN, then the output will be a (3, 20) tensor where:
58//
59// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
60// output[1, :] = (params[0, :] * 1.0) / 1.0
61// output[2, :] = (params[1, :] * 3.0) / 3.0
62//
// When indices are out of bounds, the op will not succeed.
64
65#include <stdint.h>
66
67#include <algorithm>
68#include <cmath>
69
70#include "tensorflow/lite/c/builtin_op_data.h"
71#include "tensorflow/lite/c/common.h"
72#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
73#include "tensorflow/lite/kernels/internal/tensor_utils.h"
74#include "tensorflow/lite/kernels/kernel_util.h"
75#include "tensorflow/lite/util.h"
76
77namespace tflite {
78namespace ops {
79namespace builtin {
80
81namespace {
82
83TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
84 TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
85 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
86
87 const TfLiteTensor* ids;
88 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
89 TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
90 TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
91
92 const TfLiteTensor* indices;
93 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
94 TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
95 TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
96
97 const TfLiteTensor* shape;
98 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &shape));
99 TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
100 TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
101
102 const TfLiteTensor* weights;
103 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
104 TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
105 TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
106
107 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
108 SizeOfDimension(ids, 0));
109 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
110 SizeOfDimension(weights, 0));
111
112 const TfLiteTensor* value;
113 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
114 TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
115
116 // Mark the output as a dynamic tensor.
117 TfLiteTensor* output;
118 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
119 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
120 output->allocation_type = kTfLiteDynamic;
121
122 return kTfLiteOk;
123}
124
125void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
126 float current_total_weight,
127 float current_squares_weight, int embedding_size,
128 float* output) {
129 if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
130 float multiplier = 1.0;
131 switch (combiner) {
132 case kTfLiteCombinerTypeMean:
133 multiplier = current_total_weight;
134 break;
135 case kTfLiteCombinerTypeSqrtn:
136 multiplier = std::sqrt(current_squares_weight);
137 break;
138 default:
139 break;
140 }
141 for (int k = 0; k < embedding_size; k++) {
142 output[k] /= multiplier;
143 }
144 }
145}
146
147TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
148 auto* params =
149 reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
150 TfLiteTensor* output;
151 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
152 const TfLiteTensor* ids;
153 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &ids));
154 const TfLiteTensor* indices;
155 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &indices));
156 const TfLiteTensor* dense_shape;
157 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &dense_shape));
158 const TfLiteTensor* weights;
159 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 3, &weights));
160 const TfLiteTensor* value;
161 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 4, &value));
162 const size_t values_size = NumElements(value);
163
164 const int lookup_rank = SizeOfDimension(indices, 1);
165 const int embedding_rank = NumDimensions(value);
166 const int num_lookups = SizeOfDimension(ids, 0);
167 const int num_rows = SizeOfDimension(value, 0);
168
169 // The last dimension gets replaced by the embedding.
170 const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);
171
172 // Make sure that the actual dense shape of the sparse tensor represented by
173 // (loopkup, indices, dense_shape) is consistent.
174 TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);
175
176 // Resize output tensor.
177 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
178 TF_LITE_ENSURE(context, output_shape != nullptr);
179 int k = 0;
180 size_t embedding_size = 1;
181 size_t lookup_size = 1;
182 for (int i = 0; i < lookup_rank - 1; i++, k++) {
183 const size_t dim = dense_shape->data.i32[i];
184 TF_LITE_ENSURE_MSG(
185 context,
186 MultiplyAndCheckOverflow(lookup_size, dim, &lookup_size) == kTfLiteOk,
187 "Lookup size overflowed.");
188 output_shape->data[k] = dim;
189 }
190 for (int i = 1; i < embedding_rank; i++, k++) {
191 const size_t dim = SizeOfDimension(value, i);
192 TF_LITE_ENSURE_MSG(context,
193 MultiplyAndCheckOverflow(embedding_size, dim,
194 &embedding_size) == kTfLiteOk,
195 "Embedding size overflowed.");
196 output_shape->data[k] = dim;
197 }
198 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
199 const size_t output_size = lookup_size * embedding_size;
200 TfLiteTensorRealloc(output_size * sizeof(float), output);
201
202 float* output_ptr = GetTensorData<float>(output);
203 const float* weights_ptr = GetTensorData<float>(weights);
204 const float* value_ptr = GetTensorData<float>(value);
205 // Makes sure reallocation was successful.
206 TF_LITE_ENSURE(context, output_ptr != nullptr);
207
208 std::fill_n(output_ptr, output_size, 0.0f);
209
210 // Keep track of the current bucket for aggregation/combination.
211 int current_output_offset = 0;
212 float current_total_weight = 0.0;
213 float current_squares_weight = 0.0;
214 int num_elements = 0;
215
216 for (int i = 0; i < num_lookups; i++) {
217 int idx = ids->data.i32[i];
218 if (idx >= num_rows || idx < 0) {
219 TF_LITE_KERNEL_LOG(context,
220 "Embedding Lookup Sparse: index out of bounds. "
221 "Got %d, and bounds are [0, %d]",
222 idx, num_rows - 1);
223 return kTfLiteError;
224 }
225
226 // Check where we need to aggregate.
227 const int example_indices_offset = i * lookup_rank;
228 int output_bucket = 0;
229 int stride = 1;
230 for (int k = (lookup_rank - 1) - 1; k >= 0; k--) {
231 output_bucket += indices->data.i32[example_indices_offset + k] * stride;
232 stride *= dense_shape->data.i32[k];
233 }
234 const int output_offset = output_bucket * embedding_size;
235
236 // If we are in a new aggregation bucket and the combiner is not the sum,
237 // go back and finalize the result of the previous bucket.
238 if (output_offset != current_output_offset) {
239 FinalizeAggregation(params->combiner, num_elements, current_total_weight,
240 current_squares_weight, embedding_size,
241 &output_ptr[current_output_offset]);
242
243 // Track next bucket.
244 num_elements = 0;
245 current_total_weight = 0.0;
246 current_squares_weight = 0.0;
247 current_output_offset = output_offset;
248 }
249
250 // Add element to aggregation.
251 ++num_elements;
252 const int example_embedding_offset = idx * embedding_size;
253 const float w = weights_ptr[i];
254 current_squares_weight += w * w;
255 current_total_weight += w;
256 for (int k = 0; k < embedding_size; k++) {
257 // only index if indices are valid
258 if (current_output_offset + k < 0) continue;
259 if (current_output_offset + k >= output_size) continue;
260 if (example_embedding_offset + k < 0) continue;
261 if (example_embedding_offset + k >= values_size) continue;
262 output_ptr[current_output_offset + k] +=
263 value_ptr[example_embedding_offset + k] * w;
264 }
265 }
266
267 // Finalize last bucket.
268 FinalizeAggregation(params->combiner, num_elements, current_total_weight,
269 current_squares_weight, embedding_size,
270 &GetTensorData<float>(output)[current_output_offset]);
271
272 return kTfLiteOk;
273}
274
275} // namespace
276
277TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
278 static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
279 return &r;
280}
281
282} // namespace builtin
283} // namespace ops
284} // namespace tflite
285