1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Dec 2020
 * \brief Unit tests for QueryService, constructed via QueryServiceBuilder
19 */
20
21
22#include <memory>
23#include <gtest/gtest.h>
24#include <meta/meta_impl.h>
25#include "index/mock_index_service.h" // for MockIndexService
26#include "index/mock_segment.h" // for MockSegment
27#include "meta/mock_meta_service.h" // for MockMetaService
28#include "query/query_service_builder.h"
29
30using QueryRequest = proxima::be::proto::QueryRequest;
31using QueryResponse = proxima::be::proto::QueryResponse;
32using GetDocumentRequest = proxima::be::proto::GetDocumentRequest;
33using GetDocumentResponse = proxima::be::proto::GetDocumentResponse;
34
35using KnnParam = proxima::be::proto::QueryRequest::KnnQueryParam;
36
37
38class QueryServiceTest : public Test {
39 protected:
40 // Sets up the test fixture.
41 void SetUp() override {
42 meta_service_ = std::make_shared<MockMetaService>();
43 index_service_ = std::make_shared<MockIndexService>();
44
45 init_request();
46 init_response();
47 }
48
49 // Tears down the test fixture.
50 void TearDown() override {
51 meta_service_.reset();
52 index_service_.reset();
53 }
54
55 private:
56 void init_request() {
57 request_.reset(new (std::nothrow) QueryRequest());
58 request_->set_query_type(proxima::be::proto::QueryRequest_QueryType_QT_KNN);
59 request_->set_collection_name(collection_);
60 // not supported yet
61 request_->set_debug_mode(false);
62 // Allocate KNN Param
63 param_ = request_->mutable_knn_param();
64
65 param_->set_column_name("column_name");
66 param_->set_topk(1);
67 param_->set_dimension(10);
68 param_->set_data_type(proxima::be::proto::DataType::DT_VECTOR_FP16);
69 param_->set_features("features");
70 param_->set_batch_count(1);
71 param_->set_radius(0.1f);
72
73 auto kv = param_->add_extra_params();
74 kv->set_key("string_key1");
75 kv->set_value("value1");
76 kv = param_->add_extra_params();
77 kv->set_key("int_key1");
78 kv->set_value("10");
79
80 equal_request_.reset(new (std::nothrow) GetDocumentRequest());
81 equal_request_->set_collection_name(collection_);
82 // not supported yet
83 equal_request_->set_debug_mode(false);
84 equal_request_->set_primary_key(1);
85 }
86
87 void init_response() {
88 response_.reset(new (std::nothrow) QueryResponse());
89 }
90
91 protected:
92 MockMetaServicePtr meta_service_{nullptr};
93 MockIndexServicePtr index_service_{nullptr};
94 std::unique_ptr<QueryRequest> request_{nullptr};
95 std::unique_ptr<QueryResponse> response_{nullptr};
96 KnnParam *param_{nullptr};
97 std::unique_ptr<GetDocumentRequest> equal_request_{nullptr};
98 // std::unique_ptr<GetDocumentResponse> equal_response_{nullptr};
99
100 std::string collection_{"unittest"};
101};
102
103using QueryService = proxima::be::query::QueryService;
104using QueryServiceBuilder = proxima::be::query::QueryServiceBuilder;
105
106TEST_F(QueryServiceTest, TestInitialize) {
107 EXPECT_FALSE(QueryServiceBuilder::Create(nullptr, meta_service_, 1));
108 EXPECT_FALSE(QueryServiceBuilder::Create(nullptr, nullptr, 1));
109 EXPECT_FALSE(QueryServiceBuilder::Create(index_service_, nullptr, 1));
110
111 auto svc = QueryServiceBuilder::Create(index_service_, meta_service_, 1);
112 EXPECT_TRUE(svc);
113 EXPECT_TRUE(svc->initialized());
114 EXPECT_TRUE(svc->cleanup() == 0);
115}
116
117TEST_F(QueryServiceTest, TestSearch) {
118 { // Invalid params
119 auto svc = QueryServiceBuilder::Create(index_service_, meta_service_, 1);
120 EXPECT_TRUE(svc);
121 EXPECT_TRUE(svc->initialized());
122 EXPECT_TRUE(svc->search(nullptr, nullptr, nullptr) != 0);
123 QueryRequest request;
124 request.set_query_type(
125 proxima::be::proto::
126 QueryRequest_QueryType_QueryRequest_QueryType_INT_MIN_SENTINEL_DO_NOT_USE_);
127 EXPECT_TRUE(svc->search(&request, nullptr, nullptr) != 0);
128 QueryResponse response;
129 EXPECT_TRUE(svc->search(nullptr, &response, nullptr) != 0);
130 svc->cleanup();
131 }
132
133 { // Valid KNN Search
134 // Mock Meta
135 CollectionMeta collection_meta;
136 collection_meta.mutable_forward_columns()->push_back("forward1");
137 collection_meta.mutable_forward_columns()->push_back("forward2");
138 auto column1 = std::make_shared<ColumnMeta>("column_name");
139 column1->set_data_type(DataTypes::VECTOR_FP16);
140 collection_meta.append(column1);
141
142 CollectionImplPtr collection =
143 std::make_shared<CollectionImpl>(collection_meta);
144 // Return collection
145 EXPECT_CALL(*meta_service_, get_current_collection(_))
146 .WillRepeatedly(
147 Invoke([&collection](const std::string &) -> CollectionMetaPtr {
148 return collection->meta();
149 }))
150 .RetiresOnSaturation(); // success
151
152 EXPECT_CALL(*meta_service_, get_collection(_, _))
153 .WillOnce(Invoke([&collection](const std::string &, uint64_t revision) {
154 EXPECT_EQ(revision, 1u);
155 return collection->meta();
156 }))
157 .RetiresOnSaturation();
158
159 auto segment = std::make_shared<MockSegment>();
160 testing::Mock::AllowLeak(static_cast<void *>(segment.get()));
161 // Set results
162 EXPECT_CALL(*segment, knn_search(_, _, _, _, _))
163 .WillOnce(Invoke([](const std::string &, const std::string &query,
164 const QueryParams &, uint32_t batch,
165 std::vector<QueryResultList> *results) {
166 results->clear();
167 EXPECT_EQ(batch, 1);
168 EXPECT_EQ(query, "features");
169 QueryResult result;
170 result.primary_key = 1U;
171 result.lsn = 1U;
172 result.revision = 1;
173 result.score = 0.95f;
174 proxima::be::proto::GenericValueList values;
175 auto value = values.add_values();
176 value->set_int32_value(10);
177 value = values.add_values();
178 value->set_string_value("str_value");
179 // Forward
180 result.forward_data.assign(values.SerializeAsString());
181 results->push_back({result});
182 return 0;
183 }))
184 .RetiresOnSaturation();
185
186 EXPECT_CALL(*index_service_, list_segments(_, _))
187 .WillOnce(Invoke([&segment](const std::string &,
188 index::SegmentPtrList *segments) -> int {
189 segments->push_back(segment);
190 return 0;
191 }))
192 .RetiresOnSaturation();
193
194 auto svc = QueryServiceBuilder::Create(index_service_, meta_service_, 1);
195 EXPECT_TRUE(svc);
196 auto profiler = std::make_shared<proxima::be::Profiler>(false);
197 EXPECT_EQ(svc->search(request_.get(), response_.get(), profiler), 0);
198 EXPECT_EQ(response_->results_size(), 1);
199 EXPECT_EQ(response_->results(0).documents_size(), 1);
200
201 EXPECT_EQ(response_->results(0).documents(0).primary_key(), 1U);
202 EXPECT_EQ(response_->results(0).documents(0).forward_column_values_size(),
203 2);
204
205 auto &kv = response_->results(0).documents(0).forward_column_values(0);
206 EXPECT_EQ(kv.key(), "forward1");
207 EXPECT_EQ(kv.value().int32_value(), 10);
208 auto &kv1 = response_->results(0).documents(0).forward_column_values(1);
209 EXPECT_EQ(kv1.key(), "forward2");
210 EXPECT_EQ(kv1.value().string_value(), "str_value");
211
212 response_->Clear();
213 }
214
215 //{ // Valid Equal Search
216 // // failed, no mocks with meta service
217 // EXPECT_TRUE(svc.search(equal_request_, response_) != 0);
218 // response_->Clear();
219 //}
220}
221