1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | * |
16 | * \author guonix |
17 | * \date Dec 2020 |
* \brief Unit tests for KNNQuery (validate / prepare / evaluate / finalize)
19 | */ |
20 | |
21 | #include "query/knn_query.h" |
22 | #include <gtest/gtest.h> |
23 | #include <meta/meta_impl.h> |
24 | #include "index/mock_index_service.h" // for MockIndexService |
25 | #include "index/mock_segment.h" // for MockSegment |
26 | #include "meta/mock_meta_service.h" // for MockMetaService |
27 | #include "mock_executor.h" // for MockExecutor |
28 | #include "mock_query_context.h" // for Mock*Context |
29 | |
30 | using QueryRequest = proxima::be::proto::QueryRequest; |
31 | using QueryResponse = proxima::be::proto::QueryResponse; |
32 | using KnnParam = proxima::be::proto::QueryRequest::KnnQueryParam; |
33 | |
34 | class KNNQueryTest : public Test { |
35 | protected: |
36 | // Sets up the test fixture. |
37 | void SetUp() override { |
38 | init_request(); |
39 | init_response(); |
40 | } |
41 | |
42 | // Tears down the test fixture. |
43 | void TearDown() override { |
44 | cleanup_request(); |
45 | cleanup_response(); |
46 | } |
47 | |
48 | private: |
49 | void init_request() { |
50 | request_ = new (std::nothrow) QueryRequest(); |
51 | request_->set_query_type(proxima::be::proto::QueryRequest_QueryType_QT_KNN); |
52 | request_->set_collection_name(collection_); |
53 | // not supported yet |
54 | request_->set_debug_mode(false); |
55 | // Allocate KNN Param |
56 | param_ = request_->mutable_knn_param(); |
57 | |
58 | param_->set_column_name("column_name" ); |
59 | param_->set_topk(3); |
60 | param_->set_dimension(10); |
61 | param_->set_data_type(proto::DataType::DT_VECTOR_FP16); |
62 | param_->set_features("features" ); |
63 | param_->set_batch_count(1); |
64 | param_->set_radius(0.1f); |
65 | |
66 | auto kv = param_->add_extra_params(); |
67 | kv->set_key("string_key1" ); |
68 | kv->set_value("value1" ); |
69 | kv = param_->add_extra_params(); |
70 | kv->set_key("int_key1" ); |
71 | kv->set_value("10" ); |
72 | } |
73 | |
74 | void init_response() { |
75 | response_ = new (std::nothrow) QueryResponse(); |
76 | } |
77 | |
78 | template <typename Pointer> |
79 | void delete_pointer_if(Pointer *ptr) { |
80 | delete ptr; |
81 | ptr = nullptr; |
82 | } |
83 | |
84 | void cleanup_request() { |
85 | delete_pointer_if(request_); |
86 | } |
87 | |
88 | void cleanup_response() { |
89 | delete_pointer_if(response_); |
90 | } |
91 | |
92 | protected: |
93 | QueryRequest *request_{nullptr}; |
94 | KnnParam *param_{nullptr}; |
95 | QueryResponse *response_{nullptr}; |
96 | std::string collection_{"unittest" }; |
97 | }; |
98 | |
99 | TEST_F(KNNQueryTest, TestBaseFunctional) { |
100 | auto meta_service = std::make_shared<MockMetaService>(); |
101 | |
102 | auto meta = std::make_shared<MetaWrapper>(meta_service); |
103 | // Test invalid params |
104 | auto knn = std::make_shared<KNNQuery>(0, nullptr, nullptr, nullptr, nullptr, |
105 | nullptr, nullptr); |
106 | |
107 | EXPECT_EQ(knn->mode(), IOMode::READONLY); |
108 | EXPECT_EQ(knn->type(), QueryType::KNN); |
109 | EXPECT_EQ(knn->id(), 0); |
110 | } |
111 | |
112 | TEST_F(KNNQueryTest, TestValidate) { |
113 | auto executor = std::make_shared<MockExecutor>(); |
114 | auto meta_service = std::make_shared<MockMetaService>(); |
115 | auto index_service = std::make_shared<MockIndexService>(); |
116 | |
117 | auto meta = std::make_shared<MetaWrapper>(meta_service); |
118 | // Test invalid params |
119 | auto knn = std::make_shared<KNNQuery>( |
120 | 0, nullptr, index_service, meta, executor, |
121 | std::make_shared<proxima::be::Profiler>(false), response_); |
122 | EXPECT_TRUE(knn->validate() != 0); |
123 | |
124 | knn.reset(new (std::nothrow) KNNQuery( |
125 | 0, request_, index_service, meta, nullptr, |
126 | std::make_shared<proxima::be::Profiler>(false), response_)); |
127 | EXPECT_TRUE(knn->validate() != 0); |
128 | |
129 | knn.reset(new (std::nothrow) KNNQuery( |
130 | 0, request_, index_service, meta, executor, |
131 | std::make_shared<proxima::be::Profiler>(false), nullptr)); |
132 | EXPECT_TRUE(knn->validate() != 0); |
133 | |
134 | |
135 | CollectionMeta collection_meta; |
136 | collection_meta.mutable_forward_columns()->push_back("forward1" ); |
137 | collection_meta.mutable_forward_columns()->push_back("forward2" ); |
138 | auto column1 = std::make_shared<ColumnMeta>("column_name" ); |
139 | collection_meta.append(column1); |
140 | |
141 | CollectionImplPtr collection = |
142 | std::make_shared<CollectionImpl>(collection_meta); |
143 | |
144 | // Return 0, with invalid collection meta |
145 | EXPECT_CALL(*meta_service, get_current_collection(_)) |
146 | .WillOnce(Invoke( |
147 | [](const std::string &) -> CollectionMetaPtr { return nullptr; })) |
148 | .RetiresOnSaturation(); // success |
149 | // Set all right arguments |
150 | knn.reset(new (std::nothrow) KNNQuery( |
151 | 0, request_, index_service, meta, executor, |
152 | std::make_shared<proxima::be::Profiler>(false), response_)); |
153 | // Can't valid column from meta wrapper |
154 | EXPECT_TRUE(knn->validate() != 0); |
155 | |
156 | |
157 | // Return collection |
158 | EXPECT_CALL(*meta_service, get_current_collection("unittest" )) |
159 | .WillOnce(Invoke([&collection](const std::string &) -> CollectionMetaPtr { |
160 | return collection->meta(); |
161 | })) |
162 | .RetiresOnSaturation(); // success |
163 | |
164 | EXPECT_EQ(knn->validate(), 0); |
165 | |
166 | EXPECT_EQ(knn->column(), "column_name" ); |
167 | EXPECT_EQ(knn->batch_count(), 1); |
168 | } |
169 | |
170 | TEST_F(KNNQueryTest, TestPrepare) { |
171 | auto executor = std::make_shared<MockExecutor>(); |
172 | auto meta_service = std::make_shared<MockMetaService>(); |
173 | auto index_service = std::make_shared<MockIndexService>(); |
174 | |
175 | EXPECT_CALL(*index_service, list_segments(collection_, _)) |
176 | .WillOnce(Return(1)) |
177 | .WillOnce(Return(0)) // Success but no available segments |
178 | .RetiresOnSaturation(); |
179 | |
180 | auto meta = std::make_shared<MetaWrapper>(meta_service); |
181 | auto knn = std::make_shared<KNNQuery>( |
182 | 0, request_, index_service, meta, executor, |
183 | std::make_shared<proxima::be::Profiler>(false), response_); |
184 | EXPECT_TRUE(knn->prepare() != 0); |
185 | EXPECT_TRUE(knn->prepare() != 0); |
186 | } |
187 | |
188 | TEST_F(KNNQueryTest, TestEvaluate) { |
189 | auto executor = std::make_shared<MockExecutor>(); |
190 | auto meta_service = std::make_shared<MockMetaService>(); |
191 | auto index_service = std::make_shared<MockIndexService>(); |
192 | auto segment = std::make_shared<MockSegment>(); |
193 | |
194 | EXPECT_CALL(*index_service, list_segments(_, _)) |
195 | .WillRepeatedly( |
196 | Invoke([&segment](const std::string &, |
197 | index::SegmentPtrList *segments) -> int { |
198 | EXPECT_TRUE(segments != nullptr); |
199 | segments->push_back(segment); |
200 | return 0; |
201 | })); |
202 | |
203 | CollectionMeta collection_meta; |
204 | collection_meta.mutable_forward_columns()->push_back("forward1" ); |
205 | collection_meta.mutable_forward_columns()->push_back("forward2" ); |
206 | auto column1 = std::make_shared<ColumnMeta>("column_name" ); |
207 | column1->set_data_type(DataTypes::VECTOR_FP16); |
208 | collection_meta.append(column1); |
209 | |
210 | CollectionImplPtr collection = |
211 | std::make_shared<CollectionImpl>(collection_meta); |
212 | |
213 | EXPECT_CALL(*meta_service, get_current_collection(_)) |
214 | .WillOnce(Invoke( |
215 | [](const std::string &) -> CollectionMetaPtr { return nullptr; })) |
216 | .WillOnce(Invoke([&collection](const std::string &) -> CollectionMetaPtr { |
217 | return collection->meta(); |
218 | })) |
219 | .RetiresOnSaturation(); // success |
220 | |
221 | auto meta = std::make_shared<MetaWrapper>(meta_service); |
222 | auto knn = std::make_shared<KNNQuery>( |
223 | 0, request_, index_service, meta, executor, |
224 | std::make_shared<proxima::be::Profiler>(false), response_); |
225 | EXPECT_EQ(knn->prepare(), PROXIMA_BE_ERROR_CODE(MismatchedDataType)); |
226 | EXPECT_EQ(knn->prepare(), 0); |
227 | |
228 | { // evaluate failed with fake execute |
229 | EXPECT_CALL(*executor, execute_tasks(_)) |
230 | .WillOnce(Return(0)) |
231 | .RetiresOnSaturation(); |
232 | |
233 | EXPECT_TRUE(knn->evaluate() != 0); |
234 | } |
235 | |
236 | { // Evaluate success, but no enough values |
237 | // Execute task |
238 | EXPECT_CALL(*executor, execute_tasks(_)) |
239 | .WillOnce(Invoke([](const TaskPtrList &tasks) { |
240 | for (auto &task : tasks) { |
241 | task->status(Task::Status::SCHEDULED); |
242 | task->run(); |
243 | } |
244 | return 0; |
245 | })) // Fake Execute |
246 | .RetiresOnSaturation(); |
247 | |
248 | // Set results |
249 | EXPECT_CALL(*segment, knn_search(_, _, _, _, _)) |
250 | .WillRepeatedly(Invoke([](const std::string &, const std::string &, |
251 | const QueryParams &, uint32_t batch, |
252 | std::vector<QueryResultList> *results) { |
253 | EXPECT_EQ(batch, 1); |
254 | results->push_back({}); |
255 | return 0; |
256 | })) |
257 | .RetiresOnSaturation(); |
258 | |
259 | // No enough results |
260 | EXPECT_EQ(knn->evaluate(), 0); |
261 | } |
262 | |
263 | response_->Clear(); |
264 | |
265 | { // Test serialize |
266 | EXPECT_CALL(*executor, execute_tasks(_)) |
267 | .WillRepeatedly(Invoke([](const TaskPtrList &tasks) { |
268 | for (auto &task : tasks) { |
269 | task->status(Task::Status::SCHEDULED); |
270 | task->run(); |
271 | } |
272 | return 0; |
273 | })) // Fake Execute |
274 | .RetiresOnSaturation(); |
275 | |
276 | CollectionImplPtr collection_impl = nullptr; |
277 | { // Init collection |
278 | CollectionMeta temp_meta; |
279 | temp_meta.mutable_forward_columns()->push_back("forward1" ); |
280 | temp_meta.mutable_forward_columns()->push_back("forward2" ); |
281 | collection_impl.reset(new CollectionImpl(temp_meta)); |
282 | } |
283 | |
284 | testing::Mock::AllowLeak(static_cast<void *>(meta_service.get())); |
285 | EXPECT_CALL(*meta_service, get_collection(_, _)) |
286 | .WillRepeatedly(Invoke([&collection_impl](const std::string &collection, |
287 | uint64_t revision) { |
288 | EXPECT_EQ(revision, 1u); |
289 | return collection_impl->meta(); |
290 | })) |
291 | .RetiresOnSaturation(); |
292 | |
293 | // Set results |
294 | EXPECT_CALL(*segment, knn_search(_, _, _, _, _)) |
295 | .WillRepeatedly(Invoke([](const std::string &, const std::string &, |
296 | const QueryParams &, uint32_t batch, |
297 | std::vector<QueryResultList> *results) { |
298 | results->clear(); |
299 | |
300 | EXPECT_EQ(batch, 1); |
301 | { |
302 | QueryResult result95; |
303 | result95.primary_key = 1U; |
304 | result95.lsn = 1U; |
305 | result95.revision = 1; |
306 | result95.score = 0.95f; |
307 | { |
308 | proxima::be::proto::GenericValueList values; |
309 | auto value = values.add_values(); |
310 | value->set_int32_value(10); |
311 | value = values.add_values(); |
312 | value->set_string_value("strvalue" ); |
313 | // Forward |
314 | result95.forward_data.assign(values.SerializeAsString()); |
315 | } |
316 | |
317 | QueryResult result96; |
318 | result96.primary_key = 2U; |
319 | result96.lsn = 1U; |
320 | result96.revision = 1; |
321 | result96.score = 0.96f; |
322 | { |
323 | proxima::be::proto::GenericValueList values; |
324 | auto value = values.add_values(); |
325 | value->set_int32_value(10); |
326 | value = values.add_values(); |
327 | value->set_string_value("strvalue" ); |
328 | // Forward |
329 | result96.forward_data.assign(values.SerializeAsString()); |
330 | } |
331 | QueryResult result93; |
332 | result93.primary_key = 3U; |
333 | result93.lsn = 1U; |
334 | result93.revision = 1; |
335 | result93.score = 0.93f; |
336 | { |
337 | proxima::be::proto::GenericValueList values; |
338 | auto value = values.add_values(); |
339 | value->set_int32_value(10); |
340 | value = values.add_values(); |
341 | value->set_string_value("strvalue" ); |
342 | // Forward |
343 | result93.forward_data.assign(values.SerializeAsString()); |
344 | } |
345 | results->push_back({result93, result95, result96}); |
346 | } |
347 | return 0; |
348 | })) |
349 | .RetiresOnSaturation(); |
350 | |
351 | // No enough results |
352 | EXPECT_EQ(knn->evaluate(), 0); |
353 | EXPECT_EQ(response_->results_size(), 1); |
354 | EXPECT_EQ(response_->results(0).documents_size(), 3); |
355 | |
356 | EXPECT_EQ(response_->results(0).documents(0).primary_key(), 3U); |
357 | EXPECT_EQ(response_->results(0).documents(0).forward_column_values_size(), |
358 | 2); |
359 | auto &kv = response_->results(0).documents(0).forward_column_values(0); |
360 | EXPECT_EQ(kv.key(), "forward1" ); |
361 | EXPECT_EQ(kv.value().int32_value(), 10); |
362 | auto &kv1 = response_->results(0).documents(0).forward_column_values(1); |
363 | EXPECT_EQ(kv1.key(), "forward2" ); |
364 | EXPECT_EQ(kv1.value().string_value(), "strvalue" ); |
365 | EXPECT_EQ(response_->results(0).documents(1).primary_key(), 3U); |
366 | EXPECT_EQ(response_->results(0).documents(2).primary_key(), 1U); |
367 | } |
368 | |
369 | response_->Clear(); |
370 | } |
371 | |
372 | TEST_F(KNNQueryTest, TestFinalize) { |
373 | auto executor = std::make_shared<MockExecutor>(); |
374 | auto meta_service = std::make_shared<MockMetaService>(); |
375 | auto index_service = std::make_shared<MockIndexService>(); |
376 | auto meta = std::make_shared<MetaWrapper>(meta_service); |
377 | |
378 | auto knn = std::make_shared<KNNQuery>( |
379 | 0, request_, index_service, meta, executor, |
380 | std::make_shared<proxima::be::Profiler>(false), response_); |
381 | |
382 | EXPECT_EQ(knn->finalize(), 0); |
383 | } |
384 | |