1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // LSH Projection projects an input to a bit vector via locality sensitive |
17 | // hashing. |
18 | // |
19 | // Options: |
20 | // Sparse: |
21 | // Computed bit vector is considered to be sparse. |
//     Each output element is an int32 made up of multiple bits computed from
23 | // hash functions. |
24 | // |
25 | // Dense: |
26 | // Computed bit vector is considered to be dense. Each output element is |
27 | // either 0 or 1 that represents a bit. |
28 | // |
29 | // Input: |
30 | // Tensor[0]: Hash functions. Dim.size == 2, DataType: Float. |
31 | // Tensor[0].Dim[0]: Num of hash functions. Must be at least 1. |
32 | // Tensor[0].Dim[1]: Num of projected output bits generated by |
33 | // each hash function. |
//   In the sparse case, Tensor[0].Dim[1] + ceil(log2(Tensor[0].Dim[0])) <= 32.
35 | // |
36 | // Tensor[1]: Input. Dim.size >= 1, No restriction on DataType. |
37 | // Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float. |
//   If not set, each element of input is considered to have the same
//   weight of 1.0. Otherwise Tensor[1].Dim[0] == Tensor[2].Dim[0].
40 | // |
41 | // Output: |
42 | // Sparse: |
43 | // Output.Dim == { Tensor[0].Dim[0] } |
//     A tensor of int32 that represents hash signatures.
45 | // |
//     NOTE: To avoid collisions across hash functions, an offset value of
//     k * (1 << Tensor[0].Dim[1]) is added to each signature, where k is
//     the index of the hash function.
49 | // Dense: |
50 | // Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } |
//     A flattened tensor that represents the projected bit vectors.
52 | |
53 | #include <stddef.h> |
54 | #include <stdint.h> |
55 | |
56 | #include <cstring> |
57 | #include <memory> |
58 | |
59 | #include "tensorflow/lite/c/builtin_op_data.h" |
60 | #include "tensorflow/lite/c/common.h" |
61 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
62 | #include "tensorflow/lite/kernels/kernel_util.h" |
63 | #include <farmhash.h> |
64 | |
65 | namespace tflite { |
66 | namespace ops { |
67 | namespace builtin { |
68 | namespace lsh_projection { |
69 | |
70 | TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { |
71 | auto* params = |
72 | reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data); |
73 | TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); |
74 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
75 | |
76 | const TfLiteTensor* hash; |
77 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash)); |
78 | TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); |
79 | // Support up to 32 bits. |
80 | TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); |
81 | |
82 | const TfLiteTensor* input; |
83 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input)); |
84 | TF_LITE_ENSURE(context, NumDimensions(input) >= 1); |
85 | TF_LITE_ENSURE(context, SizeOfDimension(input, 0) >= 1); |
86 | |
87 | if (NumInputs(node) == 3) { |
88 | const TfLiteTensor* weight; |
89 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &weight)); |
90 | TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); |
91 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), |
92 | SizeOfDimension(input, 0)); |
93 | } |
94 | |
95 | TfLiteTensor* output; |
96 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
97 | TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); |
98 | switch (params->type) { |
99 | case kTfLiteLshProjectionSparse: |
100 | outputSize->data[0] = SizeOfDimension(hash, 0); |
101 | break; |
102 | case kTfLiteLshProjectionDense: |
103 | outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1); |
104 | break; |
105 | default: |
106 | return kTfLiteError; |
107 | } |
108 | return context->ResizeTensor(context, output, outputSize); |
109 | } |
110 | |
111 | // Compute sign bit of dot product of hash(seed, input) and weight. |
112 | // NOTE: use float as seed, and convert it to double as a temporary solution |
113 | // to match the trained model. This is going to be changed once the new |
114 | // model is trained in an optimized method. |
115 | // |
116 | int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight, |
117 | float seed) { |
118 | double score = 0.0; |
119 | int input_item_bytes = input->bytes / SizeOfDimension(input, 0); |
120 | char* input_ptr = input->data.raw; |
121 | |
122 | const size_t seed_size = sizeof(float); |
123 | const size_t key_bytes = sizeof(float) + input_item_bytes; |
124 | std::unique_ptr<char[]> key(new char[key_bytes]); |
125 | |
126 | const float* weight_ptr = GetTensorData<float>(weight); |
127 | |
128 | for (int i = 0; i < SizeOfDimension(input, 0); ++i) { |
129 | // Create running hash id and value for current dimension. |
130 | memcpy(key.get(), &seed, seed_size); |
131 | memcpy(key.get() + seed_size, input_ptr, input_item_bytes); |
132 | |
133 | int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes); |
134 | double running_value = static_cast<double>(hash_signature); |
135 | input_ptr += input_item_bytes; |
136 | if (weight_ptr == nullptr) { |
137 | score += running_value; |
138 | } else { |
139 | score += weight_ptr[i] * running_value; |
140 | } |
141 | } |
142 | |
143 | return (score > 0) ? 1 : 0; |
144 | } |
145 | |
146 | void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, |
147 | const TfLiteTensor* weight, int32_t* out_buf) { |
148 | int num_hash = SizeOfDimension(hash, 0); |
149 | int num_bits = SizeOfDimension(hash, 1); |
150 | for (int i = 0; i < num_hash; i++) { |
151 | int32_t hash_signature = 0; |
152 | for (int j = 0; j < num_bits; j++) { |
153 | float seed = GetTensorData<float>(hash)[i * num_bits + j]; |
154 | int bit = RunningSignBit(input, weight, seed); |
155 | hash_signature = (hash_signature << 1) | bit; |
156 | } |
157 | *out_buf++ = hash_signature + i * (1 << num_bits); |
158 | } |
159 | } |
160 | |
161 | void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, |
162 | const TfLiteTensor* weight, int32_t* out_buf) { |
163 | int num_hash = SizeOfDimension(hash, 0); |
164 | int num_bits = SizeOfDimension(hash, 1); |
165 | for (int i = 0; i < num_hash; i++) { |
166 | for (int j = 0; j < num_bits; j++) { |
167 | float seed = GetTensorData<float>(hash)[i * num_bits + j]; |
168 | int bit = RunningSignBit(input, weight, seed); |
169 | *out_buf++ = bit; |
170 | } |
171 | } |
172 | } |
173 | |
174 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
175 | auto* params = |
176 | reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data); |
177 | |
178 | TfLiteTensor* out_tensor; |
179 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &out_tensor)); |
180 | int32_t* out_buf = out_tensor->data.i32; |
181 | const TfLiteTensor* hash; |
182 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &hash)); |
183 | const TfLiteTensor* input; |
184 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &input)); |
185 | const TfLiteTensor* weight = |
186 | NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2); |
187 | |
188 | switch (params->type) { |
189 | case kTfLiteLshProjectionDense: |
190 | DenseLshProjection(hash, input, weight, out_buf); |
191 | break; |
192 | case kTfLiteLshProjectionSparse: |
193 | SparseLshProjection(hash, input, weight, out_buf); |
194 | break; |
195 | default: |
196 | return kTfLiteError; |
197 | } |
198 | |
199 | return kTfLiteOk; |
200 | } |
201 | } // namespace lsh_projection |
202 | |
203 | TfLiteRegistration* Register_LSH_PROJECTION() { |
204 | static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize, |
205 | lsh_projection::Eval}; |
206 | return &r; |
207 | } |
208 | |
209 | } // namespace builtin |
210 | } // namespace ops |
211 | } // namespace tflite |
212 | |