1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// Op that looks up rows of a matrix.
//
// Input:
// Tensor[0]: Row numbers to look up, dim.size == 1, int32.
// Tensor[1]: Matrix of (possibly multi-dimensional) items,
//            dim.size >= 2, float32, uint8 or int8.
//            The first dimension is the row, the second is the column.
//
// Output:
// Output.dim[0] == Tensor[0].dim[0], the number of lookups.
// Output.dim[i] == Tensor[1].dim[i] for i >= 1; dim[1] is the number of
// items per row.
// Each item in the output is a raw byte copy of the corresponding item in the
// input or, when a uint8/int8 input is paired with a float32 output, a
// dequantized value.
// When an index is out of bounds, the op fails.
//
31
32#include <stdint.h>
33
34#include <cstring>
35
36#include "tensorflow/lite/c/common.h"
37#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
38#include "tensorflow/lite/kernels/kernel_util.h"
39
40namespace tflite {
41namespace ops {
42namespace builtin {
43namespace embedding_lookup {
44
45TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
46 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
47 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
48
49 const TfLiteTensor* lookup;
50 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
51 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
52 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
53
54 const TfLiteTensor* value;
55 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
56 TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
57
58 TfLiteTensor* output;
59 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
60 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
61
62 outputSize->data[0] = SizeOfDimension(lookup, 0);
63 outputSize->data[1] = SizeOfDimension(value, 1);
64 for (int i = 2; i < NumDimensions(value); i++) {
65 outputSize->data[i] = SizeOfDimension(value, i);
66 }
67 return context->ResizeTensor(context, output, outputSize);
68}
69
70TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
71 const TfLiteTensor* lookup, const TfLiteTensor* value,
72 TfLiteTensor* output) {
73 const int row_size = SizeOfDimension(value, 0);
74 if (row_size == 0) {
75 // Propagate empty tensor if input is empty
76 return kTfLiteOk;
77 }
78 const int row_bytes = value->bytes / row_size;
79
80 char* output_raw = GetTensorData<char>(output);
81 const char* value_raw = GetTensorData<char>(value);
82 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
83 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
84 int idx = lookup_data[i];
85 if (idx >= row_size || idx < 0) {
86 TF_LITE_KERNEL_LOG(context,
87 "Embedding Lookup: index out of bounds. "
88 "Got %d, and bounds are [0, %d]",
89 idx, row_size - 1);
90 return kTfLiteError;
91 } else {
92 std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
93 row_bytes);
94 }
95 }
96
97 return kTfLiteOk;
98}
99
100TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
101 const TfLiteTensor* lookup, const TfLiteTensor* value,
102 TfLiteTensor* output) {
103 const int row_size = SizeOfDimension(value, 0);
104 const double scaling_factor = value->params.scale;
105
106 // col_size after we flatten tensor into 2D.
107 int col_size = 1;
108 for (int i = 1; i < NumDimensions(value); i++) {
109 col_size *= SizeOfDimension(value, i);
110 }
111
112 float* output_ptr = GetTensorData<float>(output);
113 const int8_t* value_ptr = GetTensorData<int8_t>(value);
114 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
115
116 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
117 int idx = lookup_data[i];
118 if (idx >= row_size || idx < 0) {
119 TF_LITE_KERNEL_LOG(context,
120 "Embedding Lookup: index out of bounds. "
121 "Got %d, and bounds are [0, %d]",
122 idx, row_size - 1);
123 return kTfLiteError;
124 } else {
125 // Dequantize embedding values.
126 // TODO(alanchiao): refactor scalar multiply into separate function
127 // for ease of adding a neon equivalent if ever necessary.
128 for (int j = 0; j < col_size; j++) {
129 output_ptr[j + i * col_size] =
130 value_ptr[j + idx * col_size] * scaling_factor;
131 }
132 }
133 }
134
135 return kTfLiteOk;
136}
137
138TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
139 const TfLiteTensor* lookup;
140 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
141 const TfLiteTensor* value;
142 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value));
143 TfLiteTensor* output;
144 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
145 switch (value->type) {
146 case kTfLiteFloat32:
147 return EvalSimple(context, node, lookup, value, output);
148 case kTfLiteUInt8:
149 case kTfLiteInt8:
150 if (output->type == kTfLiteFloat32) {
151 return EvalHybrid(context, node, lookup, value, output);
152 } else {
153 return EvalSimple(context, node, lookup, value, output);
154 }
155 default:
156 TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
157 return kTfLiteError;
158 }
159}
160
161} // namespace embedding_lookup
162
163TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
164 static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
165 embedding_lookup::Eval};
166 return &r;
167}
168
169} // namespace builtin
170} // namespace ops
171} // namespace tflite
172