1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Op that looks up items from hashtable.
17//
18// Input:
19// Tensor[0]: Hash key to lookup, dim.size == 1, int32
20// Tensor[1]: Key of hashtable, dim.size == 1, int32
21// *MUST* be sorted in ascending order.
//   Tensor[2]: Value of hashtable, dim.size >= 1 (dim.size == 1 when the
//              values are strings).
//   Tensor[1].Dim[0] == Tensor[2].Dim[0]
24//
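// Output:
//   Output[0].dim[0] == Tensor[0].dim[0], num of lookups
//     Each item in output is a raw bytes copy of the row of Tensor[2] (the
//     hashtable value) whose key matched the corresponding lookup.
//     When the key does not exist in the hashtable, the returned bytes are
//     all 0s (for string values, an empty string is returned instead).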
29//
30// Output[1].dim = { Tensor[0].dim[0] }, num of lookups
31// Each item indicates whether the corresponding lookup has a returned value.
32// 0 for missing key, 1 for found key.
33
34#include <stdint.h>
35
36#include <cstdlib>
37#include <cstring>
38
39#include "tensorflow/lite/c/common.h"
40#include "tensorflow/lite/kernels/internal/compatibility.h"
41#include "tensorflow/lite/kernels/kernel_util.h"
42#include "tensorflow/lite/string_util.h"
43
44namespace tflite {
45namespace ops {
46namespace builtin {
47
48namespace {
49
// Three-way comparator over int32 keys, used with bsearch.
//
// Returns a negative value if *a < *b, zero if they are equal, and a positive
// value if *a > *b. The comparison is explicit rather than `*a - *b`: the
// subtraction overflows (undefined behavior) for operands of large opposite
// magnitude, e.g. INT32_MIN vs. 1. That yields the wrong sign and makes
// bsearch miss keys that are present in the table.
int greater(const void* a, const void* b) {
  const int32_t lhs = *static_cast<const int32_t*>(a);
  const int32_t rhs = *static_cast<const int32_t*>(b);
  return (lhs > rhs) - (lhs < rhs);
}
53
54TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
55 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
56 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
57
58 const TfLiteTensor* lookup;
59 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
60 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
61 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
62
63 const TfLiteTensor* key;
64 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
65 TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
66 TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
67
68 const TfLiteTensor* value;
69 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
70 TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
71 TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
72 SizeOfDimension(value, 0));
73 if (value->type == kTfLiteString) {
74 TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1);
75 }
76
77 TfLiteTensor* hits;
78 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
79 TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8);
80 TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1);
81 hitSize->data[0] = SizeOfDimension(lookup, 0);
82
83 TfLiteTensor* output;
84 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
85 TF_LITE_ENSURE_EQ(context, value->type, output->type);
86
87 TfLiteStatus status = kTfLiteOk;
88 if (output->type != kTfLiteString) {
89 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
90 outputSize->data[0] = SizeOfDimension(lookup, 0);
91 for (int i = 1; i < NumDimensions(value); i++) {
92 outputSize->data[i] = SizeOfDimension(value, i);
93 }
94 status = context->ResizeTensor(context, output, outputSize);
95 }
96 if (context->ResizeTensor(context, hits, hitSize) != kTfLiteOk) {
97 status = kTfLiteError;
98 }
99 return status;
100}
101
102TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
103 TfLiteTensor* output;
104 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
105 TfLiteTensor* hits;
106 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &hits));
107 const TfLiteTensor* lookup;
108 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &lookup));
109 const TfLiteTensor* key;
110 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &key));
111 const TfLiteTensor* value;
112 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &value));
113
114 const int num_rows = SizeOfDimension(value, 0);
115 TF_LITE_ENSURE(context, num_rows != 0);
116 const int row_bytes = value->bytes / num_rows;
117 void* pointer = nullptr;
118 DynamicBuffer buf;
119
120 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
121 int idx = -1;
122 pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows,
123 sizeof(int32_t), greater);
124 if (pointer != nullptr) {
125 idx = (reinterpret_cast<char*>(pointer) - (key->data.raw)) /
126 sizeof(int32_t);
127 }
128
129 if (idx >= num_rows || idx < 0) {
130 if (output->type == kTfLiteString) {
131 buf.AddString(nullptr, 0);
132 } else {
133 memset(output->data.raw + i * row_bytes, 0, row_bytes);
134 }
135 hits->data.uint8[i] = 0;
136 } else {
137 if (output->type == kTfLiteString) {
138 buf.AddString(GetString(value, idx));
139 } else {
140 memcpy(output->data.raw + i * row_bytes,
141 value->data.raw + idx * row_bytes, row_bytes);
142 }
143 hits->data.uint8[i] = 1;
144 }
145 }
146 if (output->type == kTfLiteString) {
147 buf.WriteToTensorAsVector(output);
148 }
149
150 return kTfLiteOk;
151}
152} // namespace
153
154TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
155 static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
156 return &r;
157}
158
159} // namespace builtin
160} // namespace ops
161} // namespace tflite
162