1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Nov 2020
18 * \brief
19 */
20
21#include "knn_query.h"
22#include "common/error_code.h"
23#include "common/logger.h"
24#include "common/transformer.h"
25#include "common/types_helper.h"
26
27namespace proxima {
28namespace be {
29namespace query {
30
31//! Constructor
32KNNQuery::KNNQuery(uint64_t traceID, const proto::QueryRequest *req,
33 index::IndexServicePtr index, MetaWrapperPtr meta_wrapper,
34 ExecutorPtr executor_ptr, ProfilerPtr profiler_ptr,
35 proto::QueryResponse *resp)
36 : CollectionQuery(traceID, req, std::move(index), std::move(meta_wrapper),
37 std::move(executor_ptr), std::move(profiler_ptr), resp) {}
38
39//! Destructor
40KNNQuery::~KNNQuery() = default;
41
42//! Validate query object, 0 for valid, otherwise non zero returned
43int KNNQuery::validate() const {
44 ScopedLatency latency("validate", profiler());
45 int code = CollectionQuery::validate();
46 if (code == 0) {
47 if (valid_response() && valid_executor()) {
48 code = meta()->validate_column(collection(), column());
49 } else {
50 LOG_WARN("Invalid response or executor passed to KNNQuery");
51 code = PROXIMA_BE_ERROR_CODE(InvalidArgument);
52 }
53 }
54 return code;
55}
56
57//! Retrieve IOMode of query
58IOMode KNNQuery::mode() const {
59 return IOMode::READONLY;
60}
61
62//! Retrieve the type of query, Readonly
63QueryType KNNQuery::type() const {
64 return QueryType::KNN;
65}
66
67//! Prepare resources, 0 for success, otherwise failed
68int KNNQuery::prepare() {
69 ScopedLatency latency("prepare", profiler());
70 index::SegmentPtrList segments;
71 int code = list_segments(&segments);
72 if (code != 0) {
73 return code;
74 }
75
76 for (auto &segment : segments) {
77 std::string knn_name("knn_task_");
78 knn_name.append(std::to_string(segment->segment_id()));
79 knn_name.append("_");
80 knn_name.append(std::to_string(id()));
81 tasks_.emplace_back(std::make_shared<KNNTask>(knn_name, segment, this));
82 }
83 code = build_query_param(request()->knn_param());
84 if (code != 0) {
85 LOG_ERROR("Failed build query param from request");
86 return code;
87 }
88
89 code = transform_feature(request()->knn_param());
90 if (code != 0) {
91 LOG_ERROR("Failed transform features. code[%d] what[%s]", code,
92 ErrorCode::What(code));
93 }
94 return code;
95}
96
97//! Evaluate query, and collection feedback
98int KNNQuery::evaluate() {
99 profiler()->open_stage("evaluate");
100 TaskPtrList tasks(tasks_.begin(), tasks_.end());
101 int code;
102 {
103 ScopedLatency execute_latency("execute", profiler());
104 code = executor()->execute_tasks(tasks);
105 }
106 if (code == 0) {
107 ScopedLatency merge_and_sort_latency("merge_and_sort", profiler());
108 // Merge Result
109 code = collect_result();
110 }
111 profiler()->close_stage();
112 return code;
113}
114
115//! Finalize query object
116int KNNQuery::finalize() {
117 return 0;
118}
119
120namespace {
121
122static int CollectBatchResult(const KNNTaskPtrList &tasks, uint32_t batch,
123 KNNQuery::ResultRefHeap *results) {
124 for (auto &task : tasks) {
125 if (static_cast<size_t>(batch) < task->result().size()) {
126 for (const auto &iter : task->result()[batch]) {
127 results->push(iter);
128 // Optimization: skip remained result, which more lower than last one
129 // in target heap
130 if (results->begin()->get() < iter) {
131 break;
132 }
133 }
134 } else {
135 return PROXIMA_BE_ERROR_CODE(OutOfBoundsResult);
136 }
137 }
138 return 0;
139}
140
141} // namespace
142
143int KNNQuery::collect_result() {
144 if (tasks_.empty()) {
145 return 0;
146 }
147
148 int code = 0;
149 for (uint32_t batch = 0; batch < batch_count(); batch++) {
150 ResultRefHeap results;
151 results.limit(query_param_.topk);
152 // Gather all the reference into list
153 code = CollectBatchResult(tasks_, batch, &results);
154 if (code == 0) {
155 // Transform heap to sorted vector
156 results.sort();
157 // Feed entity field
158 if (feed_entity(results, mutable_response()->add_results()) !=
159 query_param_.topk) {
160 LOG_DEBUG("No enough results to fill response");
161 }
162 } else {
163 LOG_ERROR("Collect result failed");
164 break;
165 }
166 }
167
168 return code;
169}
170
171//! Retrieve column name
172const std::string &KNNQuery::column() const {
173 return request()->knn_param().column_name();
174}
175
176const std::string &KNNQuery::features() const {
177 return features_;
178}
179
180uint32_t KNNQuery::batch_count() const {
181 return request()->knn_param().batch_count();
182}
183
184const index::QueryParams &KNNQuery::query_params() const {
185 return query_param_;
186}
187
188int KNNQuery::build_query_param(
189 const proto::QueryRequest::KnnQueryParam &param) {
190 query_param_.query_id = id();
191 query_param_.topk = param.topk();
192 query_param_.data_type = be::DataTypeCodeBook::Get(param.data_type());
193 query_param_.dimension = param.dimension();
194 query_param_.radius = param.radius();
195 query_param_.is_linear = param.is_linear();
196 be::IndexParamsHelper::SerializeToParams(param.extra_params(),
197 &query_param_.extra_params);
198 return 0;
199}
200
201uint32_t KNNQuery::feed_entity(const ResultRefList &results,
202 proto::QueryResponse::Result *result) {
203 for (const auto &iter : results) {
204 proto::Document *doc = result->add_documents();
205 doc->set_primary_key(iter.get().primary_key);
206 doc->set_score(iter.get().score);
207 // Fill forward for document
208 fill_forward(iter.get(), doc);
209 }
210 return results.size();
211}
212
213int KNNQuery::transform_feature(
214 const proto::QueryRequest::KnnQueryParam &param) {
215 int code = PROXIMA_BE_ERROR_CODE(InvalidQuery);
216 auto data_type = meta()->get_data_type(collection(), column());
217 auto value_case = param.features_value_case();
218 if (value_case == proto::QueryRequest_KnnQueryParam::kFeatures) {
219 code = Transformer::Transform(query_param_.data_type, param.features(),
220 data_type, &features_);
221 } else if (value_case == proto::QueryRequest_KnnQueryParam::kMatrix) {
222 std::function<int(const ailego::JsonValue &)> validator =
223 [this](const ailego::JsonValue &node) -> int {
224 int ret = PROXIMA_BE_ERROR_CODE(InvalidVectorFormat);
225 if (node.is_array()) {
226 auto &array = node.as_array();
227 if (!array.empty()) {
228 if (array.begin()->is_array()) {
229 ret = 0;
230 for (auto it = array.begin(); it != array.end(); ++it) {
231 if (!it->is_array() ||
232 it->as_array().size() != query_param_.dimension) {
233 ret |= PROXIMA_BE_ERROR_CODE(InvalidVectorFormat);
234 break;
235 }
236 }
237 } else if (array.size() == query_param_.dimension) {
238 ret = 0;
239 }
240 }
241 }
242 return ret;
243 };
244 switch (data_type) {
245 case DataTypes::VECTOR_FP32: {
246 std::vector<float> values;
247 Transformer::Transform(param.matrix(), &validator, &values);
248 Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP32>(values, &features_);
249 code = 0;
250 break;
251 }
252
253 case DataTypes::VECTOR_FP16: {
254 std::vector<float> values;
255 Transformer::Transform(param.matrix(), &validator, &values);
256 Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(values, &features_);
257 code = 0;
258 break;
259 }
260 case DataTypes::VECTOR_INT16: {
261 std::vector<int16_t> values;
262 Transformer::Transform(param.matrix(), &validator, &values);
263 Primary2Bytes::Bytes<int16_t, DataTypes::VECTOR_INT16>(values,
264 &features_);
265 code = 0;
266 break;
267 }
268 case DataTypes::VECTOR_INT8: {
269 std::vector<int8_t> values;
270 Transformer::Transform(param.matrix(), &validator, &values);
271 Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT8>(values,
272 &features_);
273 code = 0;
274 break;
275 }
276 case DataTypes::VECTOR_INT4: {
277 std::vector<int8_t> values;
278 Transformer::Transform(param.matrix(), &validator, &values);
279 Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(values,
280 &features_);
281 code = 0;
282 break;
283 }
284 case DataTypes::VECTOR_BINARY32: {
285 std::vector<uint32_t> values;
286 Transformer::Transform(param.matrix(), &validator, &values);
287 Primary2Bytes::Bytes<uint32_t, DataTypes::VECTOR_BINARY32>(values,
288 &features_);
289 code = 0;
290 break;
291 }
292 case DataTypes::VECTOR_BINARY64: {
293 std::vector<uint64_t> values;
294 Transformer::Transform(param.matrix(), &validator, &values);
295 Primary2Bytes::Bytes<uint64_t, DataTypes::VECTOR_BINARY64>(values,
296 &features_);
297 code = 0;
298 break;
299 }
300 default:
301 LOG_ERROR("Unsupported data type %u.", (uint32_t)data_type);
302 code = PROXIMA_BE_ERROR_CODE(InvalidDataType);
303 }
304 }
305 return code;
306}
307
308} // namespace query
309} // namespace be
310} // namespace proxima
311