1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #define private public |
18 | #define protected public |
19 | #include "common/config.h" |
20 | #undef private |
21 | #undef protected |
22 | |
23 | #include <ailego/utility/file_helper.h> |
24 | #include <gtest/gtest.h> |
25 | #include "common/defer.h" |
26 | #include "server/proxima_search_engine.h" |
27 | #include "port_helper.h" |
28 | #include "proxima_search_client.h" |
29 | |
30 | #ifdef proxima_search_engine_test_VERSION |
31 | #define PROXIMA_BE_VERSION_STRING proxima_search_engine_test_VERSION |
32 | #else |
33 | #define PROXIMA_BE_VERSION_STRING "unknown" |
34 | #endif |
35 | |
36 | using namespace proxima::be; |
37 | using namespace proxima::be::server; |
38 | |
39 | class ProximaSearchEngineTest : public ::testing::Test { |
40 | protected: |
41 | virtual void SetUp() { |
42 | ailego::FileHelper::RemoveDirectory("./test_proxima_be/" ); |
43 | |
44 | int pid; |
45 | PortHelper::GetPort(&grpc_port_, &pid); |
46 | PortHelper::GetPort(&http_port_, &pid); |
47 | PortHelper::RemovePortFile(pid); |
48 | |
49 | auto &config = Config::Instance(); |
50 | config.config_.mutable_common_config()->set_logger_type("ConsoleLogger" ); |
51 | config.config_.mutable_common_config()->set_log_directory( |
52 | "./test_proxima_be/log/" ); |
53 | config.config_.mutable_common_config()->set_protocol("grpc" ); |
54 | config.config_.mutable_common_config()->set_grpc_listen_port(grpc_port_); |
55 | config.config_.mutable_common_config()->set_http_listen_port(http_port_); |
56 | config.config_.mutable_index_config()->set_index_directory( |
57 | "./test_proxima_be/index_data/" ); |
58 | std::string work_directory; |
59 | ailego::FileHelper::GetWorkingDirectory(&work_directory); |
60 | std::string meta_uri = "sqlite://" + work_directory + |
61 | "/test_proxima_be/proxima_be_meta.sqlite" ; |
62 | config.config_.mutable_meta_config()->set_meta_uri(meta_uri); |
63 | } |
64 | |
65 | virtual void TearDown() {} |
66 | |
67 | protected: |
68 | int grpc_port_; |
69 | int http_port_; |
70 | }; |
71 | |
72 | TEST_F(ProximaSearchEngineTest, TestClient) { |
73 | auto &engine = ProximaSearchEngine::Instance(); |
74 | |
75 | int ret = engine.init(false, "" ); |
76 | ASSERT_EQ(ret, 0); |
77 | |
78 | Defer defer([&engine] { |
79 | engine.stop(); |
80 | engine.cleanup(); |
81 | }); |
82 | |
83 | engine.set_version(PROXIMA_BE_VERSION_STRING); |
84 | ret = engine.start(); |
85 | ASSERT_EQ(ret, 0); |
86 | |
87 | // Create a client |
88 | ProximaSearchClientPtr client = ProximaSearchClient::Create(); |
89 | ASSERT_TRUE(client != nullptr); |
90 | |
91 | // Connect to server |
92 | ChannelOptions options(std::string("127.0.0.1:" ) + |
93 | std::to_string(grpc_port_)); |
94 | options.timeout_ms = 60000U; |
95 | Status status = client->connect(options); |
96 | ASSERT_EQ(status.reason, "Success" ); |
97 | ASSERT_EQ(status.code, 0); |
98 | |
99 | // Create collection |
100 | CollectionConfig config; |
101 | config.collection_name = "test_collection" ; |
102 | config.forward_columns = {"fwd_column1" , "fwd_column2" , "fwd_column3" , |
103 | "fwd_column4" }; |
104 | config.index_columns = { |
105 | IndexColumnParam("test_column" , DataType::VECTOR_FP32, 8)}; |
106 | status = client->create_collection(config); |
107 | ASSERT_EQ(status.reason, "Success" ); |
108 | ASSERT_EQ(status.code, 0); |
109 | |
110 | // Describe collection |
111 | CollectionInfo collection_info; |
112 | status = client->describe_collection("test_collection" , &collection_info); |
113 | ASSERT_EQ(status.reason, "Success" ); |
114 | ASSERT_EQ(status.code, 0); |
115 | ASSERT_EQ(collection_info.collection_name, "test_collection" ); |
116 | ASSERT_EQ(collection_info.forward_columns.size(), 4); |
117 | ASSERT_EQ(collection_info.forward_columns[0], "fwd_column1" ); |
118 | ASSERT_EQ(collection_info.forward_columns[1], "fwd_column2" ); |
119 | ASSERT_EQ(collection_info.forward_columns[2], "fwd_column3" ); |
120 | ASSERT_EQ(collection_info.forward_columns[3], "fwd_column4" ); |
121 | ASSERT_EQ(collection_info.index_columns.size(), 1); |
122 | ASSERT_EQ(collection_info.index_columns[0].column_name, "test_column" ); |
123 | ASSERT_EQ(collection_info.index_columns[0].index_type, |
124 | IndexType::PROXIMA_GRAPH_INDEX); |
125 | ASSERT_EQ(collection_info.index_columns[0].data_type, DataType::VECTOR_FP32); |
126 | ASSERT_EQ(collection_info.index_columns[0].dimension, 8); |
127 | |
128 | // Insert records |
129 | WriteRequestPtr write_request = WriteRequest::Create(); |
130 | write_request->set_collection_name("test_collection" ); |
131 | write_request->add_forward_columns( |
132 | {"fwd_column1" , "fwd_column2" , "fwd_column3" , "fwd_column4" }); |
133 | write_request->add_index_column("test_column" , DataType::VECTOR_FP32, 8); |
134 | |
135 | for (int i = 0; i < 10; i++) { |
136 | WriteRequest::RowPtr row = write_request->add_row(); |
137 | row->set_primary_key(i); |
138 | row->set_operation_type(OperationType::INSERT); |
139 | row->add_index_value({i + 0.1f, i + 0.2f, i + 0.3f, i + 0.4f, i + 0.5f, |
140 | i + 0.6f, i + 0.7f, i + 0.8f}); // "test_column" |
141 | row->add_forward_value("hello" + std::to_string(i)); // "fwd_column1" |
142 | row->add_forward_value((int64_t)i); // "fwd_column2" |
143 | row->add_forward_value((float)i); // "fwd_column3" |
144 | row->add_forward_value((double)i); // "fwd_column4" |
145 | } |
146 | status = client->write(*write_request); |
147 | ASSERT_EQ(status.reason, "Success" ); |
148 | ASSERT_EQ(status.code, 0); |
149 | |
150 | // Stats collection |
151 | CollectionStats collection_stats; |
152 | status = client->stats_collection("test_collection" , &collection_stats); |
153 | ASSERT_EQ(status.reason, "Success" ); |
154 | ASSERT_EQ(status.code, 0); |
155 | |
156 | ASSERT_EQ(collection_stats.collection_name, "test_collection" ); |
157 | ASSERT_EQ(collection_stats.total_doc_count, 10); |
158 | ASSERT_EQ(collection_stats.total_segment_count, 1); |
159 | ASSERT_EQ(collection_stats.segment_stats.size(), 1); |
160 | ASSERT_EQ(collection_stats.segment_stats[0].doc_count, 10); |
161 | ASSERT_EQ(collection_stats.segment_stats[0].min_primary_key, 0); |
162 | ASSERT_EQ(collection_stats.segment_stats[0].max_primary_key, 9); |
163 | |
164 | // Query records |
165 | QueryRequestPtr query_request = QueryRequest::Create(); |
166 | QueryResponsePtr query_response = QueryResponse::Create(); |
167 | |
168 | query_request->set_collection_name("test_collection" ); |
169 | QueryRequest::KnnQueryParamPtr knn_param = |
170 | query_request->add_knn_query_param(); |
171 | knn_param->set_column_name("test_column" ); |
172 | knn_param->set_topk(10); |
173 | knn_param->set_features({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); |
174 | |
175 | status = client->query(*query_request, query_response.get()); |
176 | ASSERT_EQ(status.reason, "Success" ); |
177 | ASSERT_EQ(status.code, 0); |
178 | |
179 | ASSERT_EQ(query_response->result_count(), 1); |
180 | auto result = query_response->result(0); |
181 | ASSERT_EQ(result->document_count(), 10); |
182 | for (size_t i = 0; i < result->document_count(); i++) { |
183 | auto doc = result->document(i); |
184 | ASSERT_EQ(doc->primary_key(), i); |
185 | std::string fwd_val1; |
186 | int64_t fwd_val2; |
187 | float fwd_val3; |
188 | double fwd_val4; |
189 | doc->get_forward_value("fwd_column1" , &fwd_val1); |
190 | doc->get_forward_value("fwd_column2" , &fwd_val2); |
191 | doc->get_forward_value("fwd_column3" , &fwd_val3); |
192 | doc->get_forward_value("fwd_column4" , &fwd_val4); |
193 | ASSERT_EQ(fwd_val1, "hello" + std::to_string(i)); |
194 | ASSERT_EQ(fwd_val2, (int64_t)i); |
195 | ASSERT_EQ(fwd_val3, (float)i); |
196 | ASSERT_EQ(fwd_val4, (double)i); |
197 | } |
198 | |
199 | // test wrong forward |
200 | auto doc = result->document(3); |
201 | uint32_t wrong_type_fwd_value1; |
202 | doc->get_forward_value("fwd_column4" , &wrong_type_fwd_value1); |
203 | ASSERT_EQ(wrong_type_fwd_value1, 0); |
204 | |
205 | uint64_t wrong_type_fwd_value2; |
206 | doc->get_forward_value("fwd_column3" , &wrong_type_fwd_value2); |
207 | ASSERT_EQ(wrong_type_fwd_value2, 0); |
208 | |
209 | bool wrong_type_fwd_value3; |
210 | doc->get_forward_value("fwd_column2" , &wrong_type_fwd_value3); |
211 | ASSERT_EQ(wrong_type_fwd_value3, false); |
212 | |
213 | // test insert json format |
214 | WriteRequestPtr write_request2 = WriteRequest::Create(); |
215 | write_request2->set_collection_name("test_collection" ); |
216 | write_request2->add_forward_columns( |
217 | {"fwd_column1" , "fwd_column2" , "fwd_column3" , "fwd_column4" }); |
218 | write_request2->add_index_column("test_column" , DataType::VECTOR_FP32, 8); |
219 | WriteRequest::RowPtr row = write_request2->add_row(); |
220 | row->set_primary_key(10); |
221 | row->set_operation_type(OperationType::INSERT); |
222 | row->add_index_value_by_json( |
223 | "[10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8]" ); |
224 | row->add_forward_value("hello" + std::to_string(10)); |
225 | row->add_forward_value((int64_t)10); |
226 | row->add_forward_value((float)10); |
227 | row->add_forward_value((double)10); |
228 | status = client->write(*write_request2); |
229 | ASSERT_EQ(status.reason, "Success" ); |
230 | ASSERT_EQ(status.code, 0); |
231 | |
232 | // test query json format |
233 | QueryRequestPtr query_request2 = QueryRequest::Create(); |
234 | QueryResponsePtr query_response2 = QueryResponse::Create(); |
235 | query_request2->set_collection_name("test_collection" ); |
236 | auto knn_param2 = query_request2->add_knn_query_param(); |
237 | knn_param2->set_column_name("test_column" ); |
238 | knn_param2->set_topk(10); |
239 | knn_param2->set_features_by_json( |
240 | "[10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8]" ); |
241 | knn_param2->set_data_type(DataType::VECTOR_FP32); |
242 | knn_param2->set_dimension(8); |
243 | status = client->query(*query_request2, query_response2.get()); |
244 | ASSERT_EQ(status.reason, "Success" ); |
245 | ASSERT_EQ(status.code, 0); |
246 | |
247 | ASSERT_EQ(query_response2->result_count(), 1); |
248 | auto result2 = query_response2->result(0); |
249 | ASSERT_EQ(result2->document_count(), 10); |
250 | { |
251 | auto doc = result2->document(0); |
252 | ASSERT_EQ(doc->primary_key(), 10); |
253 | ASSERT_EQ(doc->score(), 0.0f); |
254 | std::string fwd_val1; |
255 | int64_t fwd_val2; |
256 | float fwd_val3; |
257 | double fwd_val4; |
258 | doc->get_forward_value("fwd_column1" , &fwd_val1); |
259 | doc->get_forward_value("fwd_column2" , &fwd_val2); |
260 | doc->get_forward_value("fwd_column3" , &fwd_val3); |
261 | doc->get_forward_value("fwd_column4" , &fwd_val4); |
262 | ASSERT_EQ(fwd_val1, "hello10" ); |
263 | ASSERT_EQ(fwd_val2, (int64_t)10); |
264 | ASSERT_EQ(fwd_val3, (float)10); |
265 | ASSERT_EQ(fwd_val4, (double)10); |
266 | } |
267 | |
268 | // Drop collection |
269 | status = client->drop_collection("test_collection" ); |
270 | ASSERT_EQ(status.reason, "Success" ); |
271 | ASSERT_EQ(status.code, 0); |
272 | |
273 | ret = engine.stop(); |
274 | ASSERT_EQ(ret, 0); |
275 | |
276 | ret = engine.cleanup(); |
277 | ASSERT_EQ(ret, 0); |
278 | } |
279 | |