1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// LSH Projection projects an input to a bit vector via locality sensitive
17// hashing.
18//
19// Options:
20// Sparse:
21// Computed bit vector is considered to be sparse.
22// Each output element is an int32 made up by multiple bits computed from
23// hash functions.
24//
25// Dense:
26// Computed bit vector is considered to be dense. Each output element is
27// either 0 or 1 that represents a bit.
28//
29// Input:
30// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float.
31// Tensor[0].Dim[0]: Num of hash functions. Must be at least 1.
32// Tensor[0].Dim[1]: Num of projected output bits generated by
33// each hash function.
34// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32.
35//
36// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType.
37// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float.
//           If not set, each element of input is considered to have the same
//           weight of 1.0. Tensor[1].Dim[0] == Tensor[2].Dim[0].
40//
41// Output:
42// Sparse:
43// Output.Dim == { Tensor[0].Dim[0] }
//     A tensor of int32 that represents hash signatures.
//
//     NOTE: To avoid collisions across hash functions, an offset value of
//     k * (1 << Tensor[0].Dim[1]) will be added to each signature, where
//     k is the index of the hash function.
49// Dense:
50// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
51// A flattened tensor represents projected bit vectors.
52
53#include <stddef.h>
54#include <stdint.h>
55
56#include <cstring>
57#include <memory>
58
59#include "tensorflow/lite/c/builtin_op_data.h"
60#include "tensorflow/lite/c/common.h"
61#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
62#include "tensorflow/lite/kernels/kernel_util.h"
63#include <farmhash.h>
64
65namespace tflite {
66namespace ops {
67namespace builtin {
68namespace lsh_projection {
69
70TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
71 auto* params =
72 reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
73 TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
74 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
75
76 const TfLiteTensor* hash;
77 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
78 TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
79 // Support up to 32 bits.
80 TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
81
82 const TfLiteTensor* input;
83 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
84 TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
85 TF_LITE_ENSURE(context, SizeOfDimension(input, 0) >= 1);
86
87 if (NumInputs(node) == 3) {
88 const TfLiteTensor* weight;
89 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight));
90 TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
91 TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
92 SizeOfDimension(input, 0));
93 }
94
95 TfLiteTensor* output;
96 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
97 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
98 switch (params->type) {
99 case kTfLiteLshProjectionSparse:
100 outputSize->data[0] = SizeOfDimension(hash, 0);
101 break;
102 case kTfLiteLshProjectionDense:
103 outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1);
104 break;
105 default:
106 return kTfLiteError;
107 }
108 return context->ResizeTensor(context, output, outputSize);
109}
110
111// Compute sign bit of dot product of hash(seed, input) and weight.
112// NOTE: use float as seed, and convert it to double as a temporary solution
113// to match the trained model. This is going to be changed once the new
114// model is trained in an optimized method.
115//
116int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight,
117 float seed) {
118 double score = 0.0;
119 int input_item_bytes = input->bytes / SizeOfDimension(input, 0);
120 char* input_ptr = input->data.raw;
121
122 const size_t seed_size = sizeof(float);
123 const size_t key_bytes = sizeof(float) + input_item_bytes;
124 std::unique_ptr<char[]> key(new char[key_bytes]);
125
126 const float* weight_ptr = GetTensorData<float>(weight);
127
128 for (int i = 0; i < SizeOfDimension(input, 0); ++i) {
129 // Create running hash id and value for current dimension.
130 memcpy(key.get(), &seed, seed_size);
131 memcpy(key.get() + seed_size, input_ptr, input_item_bytes);
132
133 int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes);
134 double running_value = static_cast<double>(hash_signature);
135 input_ptr += input_item_bytes;
136 if (weight_ptr == nullptr) {
137 score += running_value;
138 } else {
139 score += weight_ptr[i] * running_value;
140 }
141 }
142
143 return (score > 0) ? 1 : 0;
144}
145
146void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
147 const TfLiteTensor* weight, int32_t* out_buf) {
148 int num_hash = SizeOfDimension(hash, 0);
149 int num_bits = SizeOfDimension(hash, 1);
150 for (int i = 0; i < num_hash; i++) {
151 int32_t hash_signature = 0;
152 for (int j = 0; j < num_bits; j++) {
153 float seed = GetTensorData<float>(hash)[i * num_bits + j];
154 int bit = RunningSignBit(input, weight, seed);
155 hash_signature = (hash_signature << 1) | bit;
156 }
157 *out_buf++ = hash_signature + i * (1 << num_bits);
158 }
159}
160
161void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input,
162 const TfLiteTensor* weight, int32_t* out_buf) {
163 int num_hash = SizeOfDimension(hash, 0);
164 int num_bits = SizeOfDimension(hash, 1);
165 for (int i = 0; i < num_hash; i++) {
166 for (int j = 0; j < num_bits; j++) {
167 float seed = GetTensorData<float>(hash)[i * num_bits + j];
168 int bit = RunningSignBit(input, weight, seed);
169 *out_buf++ = bit;
170 }
171 }
172}
173
174TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
175 auto* params =
176 reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
177
178 TfLiteTensor* out_tensor;
179 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor));
180 int32_t* out_buf = out_tensor->data.i32;
181 const TfLiteTensor* hash;
182 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash));
183 const TfLiteTensor* input;
184 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input));
185 const TfLiteTensor* weight =
186 NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
187
188 switch (params->type) {
189 case kTfLiteLshProjectionDense:
190 DenseLshProjection(hash, input, weight, out_buf);
191 break;
192 case kTfLiteLshProjectionSparse:
193 SparseLshProjection(hash, input, weight, out_buf);
194 break;
195 default:
196 return kTfLiteError;
197 }
198
199 return kTfLiteOk;
200}
201} // namespace lsh_projection
202
203TfLiteRegistration* Register_LSH_PROJECTION() {
204 static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize,
205 lsh_projection::Eval};
206 return &r;
207}
208
209} // namespace builtin
210} // namespace ops
211} // namespace tflite
212