1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Op that looks up items from hashtable. |
17 | // |
18 | // Input: |
19 | // Tensor[0]: Hash key to lookup, dim.size == 1, int32 |
20 | // Tensor[1]: Key of hashtable, dim.size == 1, int32 |
21 | // *MUST* be sorted in ascending order. |
22 | // Tensor[2]: Value of hashtable, dim.size >= 1 |
23 | // Tensor[1].Dim[0] == Tensor[2].Dim[0] |
24 | // |
// Output:
//   Output[0].dim[0] == Tensor[0].dim[0], num of lookups; the remaining
//   dims equal Tensor[2].dim[1:].
//   Each row of Output[0] is a raw byte copy of the matching row of
//   Tensor[2] (the hashtable values). When a key does not exist in the
//   hashtable, the row is all 0 bytes (or an empty string when the values
//   are strings).
//
//   Output[1].dim = { Tensor[0].dim[0] }, num of lookups, uint8.
//   Each item indicates whether the corresponding lookup has a returned value:
//   0 for missing key, 1 for found key.
33 | |
34 | #include <stdint.h> |
35 | |
36 | #include <cstdlib> |
37 | #include <cstring> |
38 | |
39 | #include "tensorflow/lite/c/common.h" |
40 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
41 | #include "tensorflow/lite/kernels/kernel_util.h" |
42 | #include "tensorflow/lite/string_util.h" |
43 | |
44 | namespace tflite { |
45 | namespace ops { |
46 | namespace builtin { |
47 | |
48 | namespace { |
49 | |
// Three-way comparator over int32 keys, used with bsearch.
// Returns a negative value if *a < *b, zero if they are equal, and a positive
// value if *a > *b.
// Compares instead of subtracting: `*a - *b` overflows (undefined behavior)
// when the keys differ by more than INT32_MAX, e.g. INT32_MIN vs 1. The
// overflowed result has the wrong sign and can make bsearch miss a key that
// is present in the table.
int greater(const void* a, const void* b) {
  const int32_t lhs = *static_cast<const int32_t*>(a);
  const int32_t rhs = *static_cast<const int32_t*>(b);
  return (lhs > rhs) - (lhs < rhs);
}
53 | |
54 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
55 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
56 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); |
57 | |
58 | const TfLiteTensor* lookup; |
59 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup)); |
60 | TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); |
61 | TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); |
62 | |
63 | const TfLiteTensor* key; |
64 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key)); |
65 | TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); |
66 | TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); |
67 | |
68 | const TfLiteTensor* value; |
69 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value)); |
70 | TF_LITE_ENSURE(context, NumDimensions(value) >= 1); |
71 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), |
72 | SizeOfDimension(value, 0)); |
73 | if (value->type == kTfLiteString) { |
74 | TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); |
75 | } |
76 | |
77 | TfLiteTensor* hits; |
78 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits)); |
79 | TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); |
80 | TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); |
81 | hitSize->data[0] = SizeOfDimension(lookup, 0); |
82 | |
83 | TfLiteTensor* output; |
84 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
85 | TF_LITE_ENSURE_EQ(context, value->type, output->type); |
86 | |
87 | TfLiteStatus status = kTfLiteOk; |
88 | if (output->type != kTfLiteString) { |
89 | TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); |
90 | outputSize->data[0] = SizeOfDimension(lookup, 0); |
91 | for (int i = 1; i < NumDimensions(value); i++) { |
92 | outputSize->data[i] = SizeOfDimension(value, i); |
93 | } |
94 | status = context->ResizeTensor(context, output, outputSize); |
95 | } |
96 | if (context->ResizeTensor(context, hits, hitSize) != kTfLiteOk) { |
97 | status = kTfLiteError; |
98 | } |
99 | return status; |
100 | } |
101 | |
102 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
103 | TfLiteTensor* output; |
104 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
105 | TfLiteTensor* hits; |
106 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits)); |
107 | const TfLiteTensor* lookup; |
108 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup)); |
109 | const TfLiteTensor* key; |
110 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key)); |
111 | const TfLiteTensor* value; |
112 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value)); |
113 | |
114 | const int num_rows = SizeOfDimension(value, 0); |
115 | TF_LITE_ENSURE(context, num_rows != 0); |
116 | const int row_bytes = value->bytes / num_rows; |
117 | void* pointer = nullptr; |
118 | DynamicBuffer buf; |
119 | |
120 | for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { |
121 | int idx = -1; |
122 | pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, |
123 | sizeof(int32_t), greater); |
124 | if (pointer != nullptr) { |
125 | idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) / |
126 | sizeof(int32_t); |
127 | } |
128 | |
129 | if (idx >= num_rows || idx < 0) { |
130 | if (output->type == kTfLiteString) { |
131 | buf.AddString(nullptr, 0); |
132 | } else { |
133 | memset(output->data.raw + i * row_bytes, 0, row_bytes); |
134 | } |
135 | hits->data.uint8[i] = 0; |
136 | } else { |
137 | if (output->type == kTfLiteString) { |
138 | buf.AddString(GetString(value, idx)); |
139 | } else { |
140 | memcpy(output->data.raw + i * row_bytes, |
141 | value->data.raw + idx * row_bytes, row_bytes); |
142 | } |
143 | hits->data.uint8[i] = 1; |
144 | } |
145 | } |
146 | if (output->type == kTfLiteString) { |
147 | buf.WriteToTensorAsVector(output); |
148 | } |
149 | |
150 | return kTfLiteOk; |
151 | } |
152 | } // namespace |
153 | |
154 | TfLiteRegistration* Register_HASHTABLE_LOOKUP() { |
155 | static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; |
156 | return &r; |
157 | } |
158 | |
159 | } // namespace builtin |
160 | } // namespace ops |
161 | } // namespace tflite |
162 | |