1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | * |
16 | * \author guonix |
17 | * \date Nov 2020 |
18 | * \brief |
19 | */ |
20 | |
21 | #include "knn_query.h" |
22 | #include "common/error_code.h" |
23 | #include "common/logger.h" |
24 | #include "common/transformer.h" |
25 | #include "common/types_helper.h" |
26 | |
27 | namespace proxima { |
28 | namespace be { |
29 | namespace query { |
30 | |
31 | //! Constructor |
32 | KNNQuery::KNNQuery(uint64_t traceID, const proto::QueryRequest *req, |
33 | index::IndexServicePtr index, MetaWrapperPtr meta_wrapper, |
34 | ExecutorPtr executor_ptr, ProfilerPtr profiler_ptr, |
35 | proto::QueryResponse *resp) |
36 | : CollectionQuery(traceID, req, std::move(index), std::move(meta_wrapper), |
37 | std::move(executor_ptr), std::move(profiler_ptr), resp) {} |
38 | |
39 | //! Destructor |
40 | KNNQuery::~KNNQuery() = default; |
41 | |
42 | //! Validate query object, 0 for valid, otherwise non zero returned |
43 | int KNNQuery::validate() const { |
44 | ScopedLatency latency("validate" , profiler()); |
45 | int code = CollectionQuery::validate(); |
46 | if (code == 0) { |
47 | if (valid_response() && valid_executor()) { |
48 | code = meta()->validate_column(collection(), column()); |
49 | } else { |
50 | LOG_WARN("Invalid response or executor passed to KNNQuery" ); |
51 | code = PROXIMA_BE_ERROR_CODE(InvalidArgument); |
52 | } |
53 | } |
54 | return code; |
55 | } |
56 | |
57 | //! Retrieve IOMode of query |
58 | IOMode KNNQuery::mode() const { |
59 | return IOMode::READONLY; |
60 | } |
61 | |
62 | //! Retrieve the type of query, Readonly |
63 | QueryType KNNQuery::type() const { |
64 | return QueryType::KNN; |
65 | } |
66 | |
67 | //! Prepare resources, 0 for success, otherwise failed |
68 | int KNNQuery::prepare() { |
69 | ScopedLatency latency("prepare" , profiler()); |
70 | index::SegmentPtrList segments; |
71 | int code = list_segments(&segments); |
72 | if (code != 0) { |
73 | return code; |
74 | } |
75 | |
76 | for (auto &segment : segments) { |
77 | std::string knn_name("knn_task_" ); |
78 | knn_name.append(std::to_string(segment->segment_id())); |
79 | knn_name.append("_" ); |
80 | knn_name.append(std::to_string(id())); |
81 | tasks_.emplace_back(std::make_shared<KNNTask>(knn_name, segment, this)); |
82 | } |
83 | code = build_query_param(request()->knn_param()); |
84 | if (code != 0) { |
85 | LOG_ERROR("Failed build query param from request" ); |
86 | return code; |
87 | } |
88 | |
89 | code = transform_feature(request()->knn_param()); |
90 | if (code != 0) { |
91 | LOG_ERROR("Failed transform features. code[%d] what[%s]" , code, |
92 | ErrorCode::What(code)); |
93 | } |
94 | return code; |
95 | } |
96 | |
97 | //! Evaluate query, and collection feedback |
98 | int KNNQuery::evaluate() { |
99 | profiler()->open_stage("evaluate" ); |
100 | TaskPtrList tasks(tasks_.begin(), tasks_.end()); |
101 | int code; |
102 | { |
103 | ScopedLatency execute_latency("execute" , profiler()); |
104 | code = executor()->execute_tasks(tasks); |
105 | } |
106 | if (code == 0) { |
107 | ScopedLatency merge_and_sort_latency("merge_and_sort" , profiler()); |
108 | // Merge Result |
109 | code = collect_result(); |
110 | } |
111 | profiler()->close_stage(); |
112 | return code; |
113 | } |
114 | |
115 | //! Finalize query object |
116 | int KNNQuery::finalize() { |
117 | return 0; |
118 | } |
119 | |
120 | namespace { |
121 | |
122 | static int CollectBatchResult(const KNNTaskPtrList &tasks, uint32_t batch, |
123 | KNNQuery::ResultRefHeap *results) { |
124 | for (auto &task : tasks) { |
125 | if (static_cast<size_t>(batch) < task->result().size()) { |
126 | for (const auto &iter : task->result()[batch]) { |
127 | results->push(iter); |
128 | // Optimization: skip remained result, which more lower than last one |
129 | // in target heap |
130 | if (results->begin()->get() < iter) { |
131 | break; |
132 | } |
133 | } |
134 | } else { |
135 | return PROXIMA_BE_ERROR_CODE(OutOfBoundsResult); |
136 | } |
137 | } |
138 | return 0; |
139 | } |
140 | |
141 | } // namespace |
142 | |
143 | int KNNQuery::collect_result() { |
144 | if (tasks_.empty()) { |
145 | return 0; |
146 | } |
147 | |
148 | int code = 0; |
149 | for (uint32_t batch = 0; batch < batch_count(); batch++) { |
150 | ResultRefHeap results; |
151 | results.limit(query_param_.topk); |
152 | // Gather all the reference into list |
153 | code = CollectBatchResult(tasks_, batch, &results); |
154 | if (code == 0) { |
155 | // Transform heap to sorted vector |
156 | results.sort(); |
157 | // Feed entity field |
158 | if (feed_entity(results, mutable_response()->add_results()) != |
159 | query_param_.topk) { |
160 | LOG_DEBUG("No enough results to fill response" ); |
161 | } |
162 | } else { |
163 | LOG_ERROR("Collect result failed" ); |
164 | break; |
165 | } |
166 | } |
167 | |
168 | return code; |
169 | } |
170 | |
171 | //! Retrieve column name |
172 | const std::string &KNNQuery::column() const { |
173 | return request()->knn_param().column_name(); |
174 | } |
175 | |
176 | const std::string &KNNQuery::features() const { |
177 | return features_; |
178 | } |
179 | |
180 | uint32_t KNNQuery::batch_count() const { |
181 | return request()->knn_param().batch_count(); |
182 | } |
183 | |
184 | const index::QueryParams &KNNQuery::query_params() const { |
185 | return query_param_; |
186 | } |
187 | |
188 | int KNNQuery::build_query_param( |
189 | const proto::QueryRequest::KnnQueryParam ¶m) { |
190 | query_param_.query_id = id(); |
191 | query_param_.topk = param.topk(); |
192 | query_param_.data_type = be::DataTypeCodeBook::Get(param.data_type()); |
193 | query_param_.dimension = param.dimension(); |
194 | query_param_.radius = param.radius(); |
195 | query_param_.is_linear = param.is_linear(); |
196 | be::IndexParamsHelper::SerializeToParams(param.extra_params(), |
197 | &query_param_.extra_params); |
198 | return 0; |
199 | } |
200 | |
201 | uint32_t KNNQuery::feed_entity(const ResultRefList &results, |
202 | proto::QueryResponse::Result *result) { |
203 | for (const auto &iter : results) { |
204 | proto::Document *doc = result->add_documents(); |
205 | doc->set_primary_key(iter.get().primary_key); |
206 | doc->set_score(iter.get().score); |
207 | // Fill forward for document |
208 | fill_forward(iter.get(), doc); |
209 | } |
210 | return results.size(); |
211 | } |
212 | |
213 | int KNNQuery::transform_feature( |
214 | const proto::QueryRequest::KnnQueryParam ¶m) { |
215 | int code = PROXIMA_BE_ERROR_CODE(InvalidQuery); |
216 | auto data_type = meta()->get_data_type(collection(), column()); |
217 | auto value_case = param.features_value_case(); |
218 | if (value_case == proto::QueryRequest_KnnQueryParam::kFeatures) { |
219 | code = Transformer::Transform(query_param_.data_type, param.features(), |
220 | data_type, &features_); |
221 | } else if (value_case == proto::QueryRequest_KnnQueryParam::kMatrix) { |
222 | std::function<int(const ailego::JsonValue &)> validator = |
223 | [this](const ailego::JsonValue &node) -> int { |
224 | int ret = PROXIMA_BE_ERROR_CODE(InvalidVectorFormat); |
225 | if (node.is_array()) { |
226 | auto &array = node.as_array(); |
227 | if (!array.empty()) { |
228 | if (array.begin()->is_array()) { |
229 | ret = 0; |
230 | for (auto it = array.begin(); it != array.end(); ++it) { |
231 | if (!it->is_array() || |
232 | it->as_array().size() != query_param_.dimension) { |
233 | ret |= PROXIMA_BE_ERROR_CODE(InvalidVectorFormat); |
234 | break; |
235 | } |
236 | } |
237 | } else if (array.size() == query_param_.dimension) { |
238 | ret = 0; |
239 | } |
240 | } |
241 | } |
242 | return ret; |
243 | }; |
244 | switch (data_type) { |
245 | case DataTypes::VECTOR_FP32: { |
246 | std::vector<float> values; |
247 | Transformer::Transform(param.matrix(), &validator, &values); |
248 | Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP32>(values, &features_); |
249 | code = 0; |
250 | break; |
251 | } |
252 | |
253 | case DataTypes::VECTOR_FP16: { |
254 | std::vector<float> values; |
255 | Transformer::Transform(param.matrix(), &validator, &values); |
256 | Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(values, &features_); |
257 | code = 0; |
258 | break; |
259 | } |
260 | case DataTypes::VECTOR_INT16: { |
261 | std::vector<int16_t> values; |
262 | Transformer::Transform(param.matrix(), &validator, &values); |
263 | Primary2Bytes::Bytes<int16_t, DataTypes::VECTOR_INT16>(values, |
264 | &features_); |
265 | code = 0; |
266 | break; |
267 | } |
268 | case DataTypes::VECTOR_INT8: { |
269 | std::vector<int8_t> values; |
270 | Transformer::Transform(param.matrix(), &validator, &values); |
271 | Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT8>(values, |
272 | &features_); |
273 | code = 0; |
274 | break; |
275 | } |
276 | case DataTypes::VECTOR_INT4: { |
277 | std::vector<int8_t> values; |
278 | Transformer::Transform(param.matrix(), &validator, &values); |
279 | Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(values, |
280 | &features_); |
281 | code = 0; |
282 | break; |
283 | } |
284 | case DataTypes::VECTOR_BINARY32: { |
285 | std::vector<uint32_t> values; |
286 | Transformer::Transform(param.matrix(), &validator, &values); |
287 | Primary2Bytes::Bytes<uint32_t, DataTypes::VECTOR_BINARY32>(values, |
288 | &features_); |
289 | code = 0; |
290 | break; |
291 | } |
292 | case DataTypes::VECTOR_BINARY64: { |
293 | std::vector<uint64_t> values; |
294 | Transformer::Transform(param.matrix(), &validator, &values); |
295 | Primary2Bytes::Bytes<uint64_t, DataTypes::VECTOR_BINARY64>(values, |
296 | &features_); |
297 | code = 0; |
298 | break; |
299 | } |
300 | default: |
301 | LOG_ERROR("Unsupported data type %u." , (uint32_t)data_type); |
302 | code = PROXIMA_BE_ERROR_CODE(InvalidDataType); |
303 | } |
304 | } |
305 | return code; |
306 | } |
307 | |
308 | } // namespace query |
309 | } // namespace be |
310 | } // namespace proxima |
311 | |