1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Dec 2020
18 * \brief
19 */
20
21#include "query/knn_query.h"
22#include <gtest/gtest.h>
23#include <meta/meta_impl.h>
24#include "index/mock_index_service.h" // for MockIndexService
25#include "index/mock_segment.h" // for MockSegment
26#include "meta/mock_meta_service.h" // for MockMetaService
27#include "mock_executor.h" // for MockExecutor
28#include "mock_query_context.h" // for Mock*Context
29
30using QueryRequest = proxima::be::proto::QueryRequest;
31using QueryResponse = proxima::be::proto::QueryResponse;
32using KnnParam = proxima::be::proto::QueryRequest::KnnQueryParam;
33
34class KNNQueryTest : public Test {
35 protected:
36 // Sets up the test fixture.
37 void SetUp() override {
38 init_request();
39 init_response();
40 }
41
42 // Tears down the test fixture.
43 void TearDown() override {
44 cleanup_request();
45 cleanup_response();
46 }
47
48 private:
49 void init_request() {
50 request_ = new (std::nothrow) QueryRequest();
51 request_->set_query_type(proxima::be::proto::QueryRequest_QueryType_QT_KNN);
52 request_->set_collection_name(collection_);
53 // not supported yet
54 request_->set_debug_mode(false);
55 // Allocate KNN Param
56 param_ = request_->mutable_knn_param();
57
58 param_->set_column_name("column_name");
59 param_->set_topk(3);
60 param_->set_dimension(10);
61 param_->set_data_type(proto::DataType::DT_VECTOR_FP16);
62 param_->set_features("features");
63 param_->set_batch_count(1);
64 param_->set_radius(0.1f);
65
66 auto kv = param_->add_extra_params();
67 kv->set_key("string_key1");
68 kv->set_value("value1");
69 kv = param_->add_extra_params();
70 kv->set_key("int_key1");
71 kv->set_value("10");
72 }
73
74 void init_response() {
75 response_ = new (std::nothrow) QueryResponse();
76 }
77
78 template <typename Pointer>
79 void delete_pointer_if(Pointer *ptr) {
80 delete ptr;
81 ptr = nullptr;
82 }
83
84 void cleanup_request() {
85 delete_pointer_if(request_);
86 }
87
88 void cleanup_response() {
89 delete_pointer_if(response_);
90 }
91
92 protected:
93 QueryRequest *request_{nullptr};
94 KnnParam *param_{nullptr};
95 QueryResponse *response_{nullptr};
96 std::string collection_{"unittest"};
97};
98
99TEST_F(KNNQueryTest, TestBaseFunctional) {
100 auto meta_service = std::make_shared<MockMetaService>();
101
102 auto meta = std::make_shared<MetaWrapper>(meta_service);
103 // Test invalid params
104 auto knn = std::make_shared<KNNQuery>(0, nullptr, nullptr, nullptr, nullptr,
105 nullptr, nullptr);
106
107 EXPECT_EQ(knn->mode(), IOMode::READONLY);
108 EXPECT_EQ(knn->type(), QueryType::KNN);
109 EXPECT_EQ(knn->id(), 0);
110}
111
112TEST_F(KNNQueryTest, TestValidate) {
113 auto executor = std::make_shared<MockExecutor>();
114 auto meta_service = std::make_shared<MockMetaService>();
115 auto index_service = std::make_shared<MockIndexService>();
116
117 auto meta = std::make_shared<MetaWrapper>(meta_service);
118 // Test invalid params
119 auto knn = std::make_shared<KNNQuery>(
120 0, nullptr, index_service, meta, executor,
121 std::make_shared<proxima::be::Profiler>(false), response_);
122 EXPECT_TRUE(knn->validate() != 0);
123
124 knn.reset(new (std::nothrow) KNNQuery(
125 0, request_, index_service, meta, nullptr,
126 std::make_shared<proxima::be::Profiler>(false), response_));
127 EXPECT_TRUE(knn->validate() != 0);
128
129 knn.reset(new (std::nothrow) KNNQuery(
130 0, request_, index_service, meta, executor,
131 std::make_shared<proxima::be::Profiler>(false), nullptr));
132 EXPECT_TRUE(knn->validate() != 0);
133
134
135 CollectionMeta collection_meta;
136 collection_meta.mutable_forward_columns()->push_back("forward1");
137 collection_meta.mutable_forward_columns()->push_back("forward2");
138 auto column1 = std::make_shared<ColumnMeta>("column_name");
139 collection_meta.append(column1);
140
141 CollectionImplPtr collection =
142 std::make_shared<CollectionImpl>(collection_meta);
143
144 // Return 0, with invalid collection meta
145 EXPECT_CALL(*meta_service, get_current_collection(_))
146 .WillOnce(Invoke(
147 [](const std::string &) -> CollectionMetaPtr { return nullptr; }))
148 .RetiresOnSaturation(); // success
149 // Set all right arguments
150 knn.reset(new (std::nothrow) KNNQuery(
151 0, request_, index_service, meta, executor,
152 std::make_shared<proxima::be::Profiler>(false), response_));
153 // Can't valid column from meta wrapper
154 EXPECT_TRUE(knn->validate() != 0);
155
156
157 // Return collection
158 EXPECT_CALL(*meta_service, get_current_collection("unittest"))
159 .WillOnce(Invoke([&collection](const std::string &) -> CollectionMetaPtr {
160 return collection->meta();
161 }))
162 .RetiresOnSaturation(); // success
163
164 EXPECT_EQ(knn->validate(), 0);
165
166 EXPECT_EQ(knn->column(), "column_name");
167 EXPECT_EQ(knn->batch_count(), 1);
168}
169
170TEST_F(KNNQueryTest, TestPrepare) {
171 auto executor = std::make_shared<MockExecutor>();
172 auto meta_service = std::make_shared<MockMetaService>();
173 auto index_service = std::make_shared<MockIndexService>();
174
175 EXPECT_CALL(*index_service, list_segments(collection_, _))
176 .WillOnce(Return(1))
177 .WillOnce(Return(0)) // Success but no available segments
178 .RetiresOnSaturation();
179
180 auto meta = std::make_shared<MetaWrapper>(meta_service);
181 auto knn = std::make_shared<KNNQuery>(
182 0, request_, index_service, meta, executor,
183 std::make_shared<proxima::be::Profiler>(false), response_);
184 EXPECT_TRUE(knn->prepare() != 0);
185 EXPECT_TRUE(knn->prepare() != 0);
186}
187
// Drives KNNQuery::prepare() and evaluate() against a mocked index service,
// segment, executor and meta service. It also checks how segment results
// (including their forward data) are written into response_.
TEST_F(KNNQueryTest, TestEvaluate) {
  auto executor = std::make_shared<MockExecutor>();
  auto meta_service = std::make_shared<MockMetaService>();
  auto index_service = std::make_shared<MockIndexService>();
  auto segment = std::make_shared<MockSegment>();

  // Every list_segments() call appends the single mock segment to the output
  // list and returns 0.
  EXPECT_CALL(*index_service, list_segments(_, _))
      .WillRepeatedly(
          Invoke([&segment](const std::string &,
                            index::SegmentPtrList *segments) -> int {
            EXPECT_TRUE(segments != nullptr);
            segments->push_back(segment);
            return 0;
          }));

  // Collection meta whose "column_name" column is VECTOR_FP16, matching the
  // request's DT_VECTOR_FP16 data type.
  CollectionMeta collection_meta;
  collection_meta.mutable_forward_columns()->push_back("forward1");
  collection_meta.mutable_forward_columns()->push_back("forward2");
  auto column1 = std::make_shared<ColumnMeta>("column_name");
  column1->set_data_type(DataTypes::VECTOR_FP16);
  collection_meta.append(column1);

  CollectionImplPtr collection =
      std::make_shared<CollectionImpl>(collection_meta);

  // The first lookup returns no collection meta; the second returns the
  // collection built above.
  EXPECT_CALL(*meta_service, get_current_collection(_))
      .WillOnce(Invoke(
          [](const std::string &) -> CollectionMetaPtr { return nullptr; }))
      .WillOnce(Invoke([&collection](const std::string &) -> CollectionMetaPtr {
        return collection->meta();
      }))
      .RetiresOnSaturation();

  auto meta = std::make_shared<MetaWrapper>(meta_service);
  auto knn = std::make_shared<KNNQuery>(
      0, request_, index_service, meta, executor,
      std::make_shared<proxima::be::Profiler>(false), response_);
  // The first prepare() is expected to fail with MismatchedDataType
  // (presumably because it consumes the nullptr lookup above — confirm). The
  // second prepare() is expected to succeed.
  EXPECT_EQ(knn->prepare(), PROXIMA_BE_ERROR_CODE(MismatchedDataType));
  EXPECT_EQ(knn->prepare(), 0);

  { // execute_tasks() returns 0 without running any task, and evaluate() is
    // expected to report an error.
    EXPECT_CALL(*executor, execute_tasks(_))
        .WillOnce(Return(0))
        .RetiresOnSaturation();

    EXPECT_TRUE(knn->evaluate() != 0);
  }

  { // evaluate() succeeds even though the segment returns no documents.
    // Run every task inline, as a real executor would.
    EXPECT_CALL(*executor, execute_tasks(_))
        .WillOnce(Invoke([](const TaskPtrList &tasks) {
          for (auto &task : tasks) {
            task->status(Task::Status::SCHEDULED);
            task->run();
          }
          return 0;
        })) // Fake Execute
        .RetiresOnSaturation();

    // Each knn_search() call yields one empty result list for its batch.
    EXPECT_CALL(*segment, knn_search(_, _, _, _, _))
        .WillRepeatedly(Invoke([](const std::string &, const std::string &,
                                  const QueryParams &, uint32_t batch,
                                  std::vector<QueryResultList> *results) {
          EXPECT_EQ(batch, 1);
          results->push_back({});
          return 0;
        }))
        .RetiresOnSaturation();

    // Not enough results, but evaluation itself succeeds.
    EXPECT_EQ(knn->evaluate(), 0);
  }

  response_->Clear();

  { // Test serialization of segment results into response_.
    // Run every task inline.
    EXPECT_CALL(*executor, execute_tasks(_))
        .WillRepeatedly(Invoke([](const TaskPtrList &tasks) {
          for (auto &task : tasks) {
            task->status(Task::Status::SCHEDULED);
            task->run();
          }
          return 0;
        })) // Fake Execute
        .RetiresOnSaturation();

    // Collection meta (forward columns only) returned for revision lookups.
    CollectionImplPtr collection_impl = nullptr;
    { // Init collection
      CollectionMeta temp_meta;
      temp_meta.mutable_forward_columns()->push_back("forward1");
      temp_meta.mutable_forward_columns()->push_back("forward2");
      collection_impl.reset(new CollectionImpl(temp_meta));
    }

    // Tell gMock not to report meta_service if it is leaked.
    testing::Mock::AllowLeak(static_cast<void *>(meta_service.get()));
    // Every get_collection() call must ask for revision 1 (the revision of
    // each result below) and receives the forward-column-only meta.
    // NOTE(review): the unused `collection` parameter shadows the outer
    // `collection` variable.
    EXPECT_CALL(*meta_service, get_collection(_, _))
        .WillRepeatedly(Invoke([&collection_impl](const std::string &collection,
                                                  uint64_t revision) {
          EXPECT_EQ(revision, 1u);
          return collection_impl->meta();
        }))
        .RetiresOnSaturation();

    // Each knn_search() call replaces *results with one batch of three
    // results:
    //   pk 3 (score 0.93), pk 1 (score 0.95), pk 2 (score 0.96).
    // Each result carries forward data that serializes {int32 10, "strvalue"}.
    EXPECT_CALL(*segment, knn_search(_, _, _, _, _))
        .WillRepeatedly(Invoke([](const std::string &, const std::string &,
                                  const QueryParams &, uint32_t batch,
                                  std::vector<QueryResultList> *results) {
          results->clear();

          EXPECT_EQ(batch, 1);
          {
            QueryResult result95;
            result95.primary_key = 1U;
            result95.lsn = 1U;
            result95.revision = 1;
            result95.score = 0.95f;
            {
              proxima::be::proto::GenericValueList values;
              auto value = values.add_values();
              value->set_int32_value(10);
              value = values.add_values();
              value->set_string_value("strvalue");
              // Forward
              result95.forward_data.assign(values.SerializeAsString());
            }

            QueryResult result96;
            result96.primary_key = 2U;
            result96.lsn = 1U;
            result96.revision = 1;
            result96.score = 0.96f;
            {
              proxima::be::proto::GenericValueList values;
              auto value = values.add_values();
              value->set_int32_value(10);
              value = values.add_values();
              value->set_string_value("strvalue");
              // Forward
              result96.forward_data.assign(values.SerializeAsString());
            }
            QueryResult result93;
            result93.primary_key = 3U;
            result93.lsn = 1U;
            result93.revision = 1;
            result93.score = 0.93f;
            {
              proxima::be::proto::GenericValueList values;
              auto value = values.add_values();
              value->set_int32_value(10);
              value = values.add_values();
              value->set_string_value("strvalue");
              // Forward
              result93.forward_data.assign(values.SerializeAsString());
            }
            results->push_back({result93, result95, result96});
          }
          return 0;
        }))
        .RetiresOnSaturation();

    // evaluate() succeeds. The response holds one result (batch_count is 1)
    // with three documents (topk is 3).
    EXPECT_EQ(knn->evaluate(), 0);
    EXPECT_EQ(response_->results_size(), 1);
    EXPECT_EQ(response_->results(0).documents_size(), 3);

    // The forward data is decoded into key/value pairs named after the
    // collection's forward columns.
    EXPECT_EQ(response_->results(0).documents(0).primary_key(), 3U);
    EXPECT_EQ(response_->results(0).documents(0).forward_column_values_size(),
              2);
    auto &kv = response_->results(0).documents(0).forward_column_values(0);
    EXPECT_EQ(kv.key(), "forward1");
    EXPECT_EQ(kv.value().int32_value(), 10);
    auto &kv1 = response_->results(0).documents(0).forward_column_values(1);
    EXPECT_EQ(kv1.key(), "forward2");
    EXPECT_EQ(kv1.value().string_value(), "strvalue");
    // NOTE(review): documents 0 and 1 both expect pk 3, and document 2 expects
    // pk 1. This looks consistent with the mock segment appearing twice in the
    // query's segment list: list_segments() appends it on every call and
    // prepare() ran twice. The three lowest scores would then be kept. Confirm
    // against KNNQuery's result-merging logic.
    EXPECT_EQ(response_->results(0).documents(1).primary_key(), 3U);
    EXPECT_EQ(response_->results(0).documents(2).primary_key(), 1U);
  }

  response_->Clear();
}
371
372TEST_F(KNNQueryTest, TestFinalize) {
373 auto executor = std::make_shared<MockExecutor>();
374 auto meta_service = std::make_shared<MockMetaService>();
375 auto index_service = std::make_shared<MockIndexService>();
376 auto meta = std::make_shared<MetaWrapper>(meta_service);
377
378 auto knn = std::make_shared<KNNQuery>(
379 0, request_, index_service, meta, executor,
380 std::make_shared<proxima::be::Profiler>(false), response_);
381
382 EXPECT_EQ(knn->finalize(), 0);
383}
384