1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// Op that looks up rows of a matrix.
//
// Input:
//   Tensor[0]: Row indices to look up, dim.size == 1, int32.
//   Tensor[1]: Matrix of (possibly multi-dimensional) items, dim.size >= 2,
//              type float32, uint8 or int8. The first dimension is the row,
//              the second is the column; any further dimensions form the
//              shape of each item.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[i] == Tensor[1].dim[i] for every i >= 1 (items per row, then
//   the item shape).
//   Each item in the output is a raw byte copy of the corresponding item in
//   the input, or a dequantized float value when Tensor[1] is uint8/int8 and
//   the output is float32.
//   If any index is out of bounds, the op fails.
//
31 | |
#include <stddef.h>
#include <stdint.h>

#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
39 | |
40 | namespace tflite { |
41 | namespace ops { |
42 | namespace builtin { |
43 | namespace embedding_lookup { |
44 | |
45 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
46 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
47 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
48 | |
49 | const TfLiteTensor* lookup; |
50 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup)); |
51 | TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); |
52 | TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); |
53 | |
54 | const TfLiteTensor* value; |
55 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value)); |
56 | TF_LITE_ENSURE(context, NumDimensions(value) >= 2); |
57 | |
58 | TfLiteTensor* output; |
59 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
60 | TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); |
61 | |
62 | outputSize->data[0] = SizeOfDimension(lookup, 0); |
63 | outputSize->data[1] = SizeOfDimension(value, 1); |
64 | for (int i = 2; i < NumDimensions(value); i++) { |
65 | outputSize->data[i] = SizeOfDimension(value, i); |
66 | } |
67 | return context->ResizeTensor(context, output, outputSize); |
68 | } |
69 | |
70 | TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, |
71 | const TfLiteTensor* lookup, const TfLiteTensor* value, |
72 | TfLiteTensor* output) { |
73 | const int row_size = SizeOfDimension(value, 0); |
74 | if (row_size == 0) { |
75 | // Propagate empty tensor if input is empty |
76 | return kTfLiteOk; |
77 | } |
78 | const int row_bytes = value->bytes / row_size; |
79 | |
80 | char* output_raw = GetTensorData<char>(output); |
81 | const char* value_raw = GetTensorData<char>(value); |
82 | const int32_t* lookup_data = GetTensorData<int32_t>(lookup); |
83 | for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { |
84 | int idx = lookup_data[i]; |
85 | if (idx >= row_size || idx < 0) { |
86 | TF_LITE_KERNEL_LOG(context, |
87 | "Embedding Lookup: index out of bounds. " |
88 | "Got %d, and bounds are [0, %d]" , |
89 | idx, row_size - 1); |
90 | return kTfLiteError; |
91 | } else { |
92 | std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes, |
93 | row_bytes); |
94 | } |
95 | } |
96 | |
97 | return kTfLiteOk; |
98 | } |
99 | |
100 | TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, |
101 | const TfLiteTensor* lookup, const TfLiteTensor* value, |
102 | TfLiteTensor* output) { |
103 | const int row_size = SizeOfDimension(value, 0); |
104 | const double scaling_factor = value->params.scale; |
105 | |
106 | // col_size after we flatten tensor into 2D. |
107 | int col_size = 1; |
108 | for (int i = 1; i < NumDimensions(value); i++) { |
109 | col_size *= SizeOfDimension(value, i); |
110 | } |
111 | |
112 | float* output_ptr = GetTensorData<float>(output); |
113 | const int8_t* value_ptr = GetTensorData<int8_t>(value); |
114 | const int32_t* lookup_data = GetTensorData<int32_t>(lookup); |
115 | |
116 | for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { |
117 | int idx = lookup_data[i]; |
118 | if (idx >= row_size || idx < 0) { |
119 | TF_LITE_KERNEL_LOG(context, |
120 | "Embedding Lookup: index out of bounds. " |
121 | "Got %d, and bounds are [0, %d]" , |
122 | idx, row_size - 1); |
123 | return kTfLiteError; |
124 | } else { |
125 | // Dequantize embedding values. |
126 | // TODO(alanchiao): refactor scalar multiply into separate function |
127 | // for ease of adding a neon equivalent if ever necessary. |
128 | for (int j = 0; j < col_size; j++) { |
129 | output_ptr[j + i * col_size] = |
130 | value_ptr[j + idx * col_size] * scaling_factor; |
131 | } |
132 | } |
133 | } |
134 | |
135 | return kTfLiteOk; |
136 | } |
137 | |
138 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
139 | const TfLiteTensor* lookup; |
140 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup)); |
141 | const TfLiteTensor* value; |
142 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &value)); |
143 | TfLiteTensor* output; |
144 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
145 | switch (value->type) { |
146 | case kTfLiteFloat32: |
147 | return EvalSimple(context, node, lookup, value, output); |
148 | case kTfLiteUInt8: |
149 | case kTfLiteInt8: |
150 | if (output->type == kTfLiteFloat32) { |
151 | return EvalHybrid(context, node, lookup, value, output); |
152 | } else { |
153 | return EvalSimple(context, node, lookup, value, output); |
154 | } |
155 | default: |
156 | TF_LITE_KERNEL_LOG(context, "Type not currently supported." ); |
157 | return kTfLiteError; |
158 | } |
159 | } |
160 | |
161 | } // namespace embedding_lookup |
162 | |
163 | TfLiteRegistration* Register_EMBEDDING_LOOKUP() { |
164 | static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, |
165 | embedding_lookup::Eval}; |
166 | return &r; |
167 | } |
168 | |
169 | } // namespace builtin |
170 | } // namespace ops |
171 | } // namespace tflite |
172 | |