1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define private public
18#define protected public
19#include "common/config.h"
20#undef private
21#undef protected
22
23#include <ailego/utility/file_helper.h>
24#include <gtest/gtest.h>
25#include "common/defer.h"
26#include "server/proxima_search_engine.h"
27#include "port_helper.h"
28#include "proxima_search_client.h"
29
// Version string handed to the engine by the test below. When the build does
// not define proxima_search_engine_test_VERSION, fall back to "unknown".
#ifndef proxima_search_engine_test_VERSION
#define PROXIMA_BE_VERSION_STRING "unknown"
#else
#define PROXIMA_BE_VERSION_STRING proxima_search_engine_test_VERSION
#endif
35
36using namespace proxima::be;
37using namespace proxima::be::server;
38
39class ProximaSearchEngineTest : public ::testing::Test {
40 protected:
41 virtual void SetUp() {
42 ailego::FileHelper::RemoveDirectory("./test_proxima_be/");
43
44 int pid;
45 PortHelper::GetPort(&grpc_port_, &pid);
46 PortHelper::GetPort(&http_port_, &pid);
47 PortHelper::RemovePortFile(pid);
48
49 auto &config = Config::Instance();
50 config.config_.mutable_common_config()->set_logger_type("ConsoleLogger");
51 config.config_.mutable_common_config()->set_log_directory(
52 "./test_proxima_be/log/");
53 config.config_.mutable_common_config()->set_protocol("grpc");
54 config.config_.mutable_common_config()->set_grpc_listen_port(grpc_port_);
55 config.config_.mutable_common_config()->set_http_listen_port(http_port_);
56 config.config_.mutable_index_config()->set_index_directory(
57 "./test_proxima_be/index_data/");
58 std::string work_directory;
59 ailego::FileHelper::GetWorkingDirectory(&work_directory);
60 std::string meta_uri = "sqlite://" + work_directory +
61 "/test_proxima_be/proxima_be_meta.sqlite";
62 config.config_.mutable_meta_config()->set_meta_uri(meta_uri);
63 }
64
65 virtual void TearDown() {}
66
67 protected:
68 int grpc_port_;
69 int http_port_;
70};
71
72TEST_F(ProximaSearchEngineTest, TestClient) {
73 auto &engine = ProximaSearchEngine::Instance();
74
75 int ret = engine.init(false, "");
76 ASSERT_EQ(ret, 0);
77
78 Defer defer([&engine] {
79 engine.stop();
80 engine.cleanup();
81 });
82
83 engine.set_version(PROXIMA_BE_VERSION_STRING);
84 ret = engine.start();
85 ASSERT_EQ(ret, 0);
86
87 // Create a client
88 ProximaSearchClientPtr client = ProximaSearchClient::Create();
89 ASSERT_TRUE(client != nullptr);
90
91 // Connect to server
92 ChannelOptions options(std::string("127.0.0.1:") +
93 std::to_string(grpc_port_));
94 options.timeout_ms = 60000U;
95 Status status = client->connect(options);
96 ASSERT_EQ(status.reason, "Success");
97 ASSERT_EQ(status.code, 0);
98
99 // Create collection
100 CollectionConfig config;
101 config.collection_name = "test_collection";
102 config.forward_columns = {"fwd_column1", "fwd_column2", "fwd_column3",
103 "fwd_column4"};
104 config.index_columns = {
105 IndexColumnParam("test_column", DataType::VECTOR_FP32, 8)};
106 status = client->create_collection(config);
107 ASSERT_EQ(status.reason, "Success");
108 ASSERT_EQ(status.code, 0);
109
110 // Describe collection
111 CollectionInfo collection_info;
112 status = client->describe_collection("test_collection", &collection_info);
113 ASSERT_EQ(status.reason, "Success");
114 ASSERT_EQ(status.code, 0);
115 ASSERT_EQ(collection_info.collection_name, "test_collection");
116 ASSERT_EQ(collection_info.forward_columns.size(), 4);
117 ASSERT_EQ(collection_info.forward_columns[0], "fwd_column1");
118 ASSERT_EQ(collection_info.forward_columns[1], "fwd_column2");
119 ASSERT_EQ(collection_info.forward_columns[2], "fwd_column3");
120 ASSERT_EQ(collection_info.forward_columns[3], "fwd_column4");
121 ASSERT_EQ(collection_info.index_columns.size(), 1);
122 ASSERT_EQ(collection_info.index_columns[0].column_name, "test_column");
123 ASSERT_EQ(collection_info.index_columns[0].index_type,
124 IndexType::PROXIMA_GRAPH_INDEX);
125 ASSERT_EQ(collection_info.index_columns[0].data_type, DataType::VECTOR_FP32);
126 ASSERT_EQ(collection_info.index_columns[0].dimension, 8);
127
128 // Insert records
129 WriteRequestPtr write_request = WriteRequest::Create();
130 write_request->set_collection_name("test_collection");
131 write_request->add_forward_columns(
132 {"fwd_column1", "fwd_column2", "fwd_column3", "fwd_column4"});
133 write_request->add_index_column("test_column", DataType::VECTOR_FP32, 8);
134
135 for (int i = 0; i < 10; i++) {
136 WriteRequest::RowPtr row = write_request->add_row();
137 row->set_primary_key(i);
138 row->set_operation_type(OperationType::INSERT);
139 row->add_index_value({i + 0.1f, i + 0.2f, i + 0.3f, i + 0.4f, i + 0.5f,
140 i + 0.6f, i + 0.7f, i + 0.8f}); // "test_column"
141 row->add_forward_value("hello" + std::to_string(i)); // "fwd_column1"
142 row->add_forward_value((int64_t)i); // "fwd_column2"
143 row->add_forward_value((float)i); // "fwd_column3"
144 row->add_forward_value((double)i); // "fwd_column4"
145 }
146 status = client->write(*write_request);
147 ASSERT_EQ(status.reason, "Success");
148 ASSERT_EQ(status.code, 0);
149
150 // Stats collection
151 CollectionStats collection_stats;
152 status = client->stats_collection("test_collection", &collection_stats);
153 ASSERT_EQ(status.reason, "Success");
154 ASSERT_EQ(status.code, 0);
155
156 ASSERT_EQ(collection_stats.collection_name, "test_collection");
157 ASSERT_EQ(collection_stats.total_doc_count, 10);
158 ASSERT_EQ(collection_stats.total_segment_count, 1);
159 ASSERT_EQ(collection_stats.segment_stats.size(), 1);
160 ASSERT_EQ(collection_stats.segment_stats[0].doc_count, 10);
161 ASSERT_EQ(collection_stats.segment_stats[0].min_primary_key, 0);
162 ASSERT_EQ(collection_stats.segment_stats[0].max_primary_key, 9);
163
164 // Query records
165 QueryRequestPtr query_request = QueryRequest::Create();
166 QueryResponsePtr query_response = QueryResponse::Create();
167
168 query_request->set_collection_name("test_collection");
169 QueryRequest::KnnQueryParamPtr knn_param =
170 query_request->add_knn_query_param();
171 knn_param->set_column_name("test_column");
172 knn_param->set_topk(10);
173 knn_param->set_features({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
174
175 status = client->query(*query_request, query_response.get());
176 ASSERT_EQ(status.reason, "Success");
177 ASSERT_EQ(status.code, 0);
178
179 ASSERT_EQ(query_response->result_count(), 1);
180 auto result = query_response->result(0);
181 ASSERT_EQ(result->document_count(), 10);
182 for (size_t i = 0; i < result->document_count(); i++) {
183 auto doc = result->document(i);
184 ASSERT_EQ(doc->primary_key(), i);
185 std::string fwd_val1;
186 int64_t fwd_val2;
187 float fwd_val3;
188 double fwd_val4;
189 doc->get_forward_value("fwd_column1", &fwd_val1);
190 doc->get_forward_value("fwd_column2", &fwd_val2);
191 doc->get_forward_value("fwd_column3", &fwd_val3);
192 doc->get_forward_value("fwd_column4", &fwd_val4);
193 ASSERT_EQ(fwd_val1, "hello" + std::to_string(i));
194 ASSERT_EQ(fwd_val2, (int64_t)i);
195 ASSERT_EQ(fwd_val3, (float)i);
196 ASSERT_EQ(fwd_val4, (double)i);
197 }
198
199 // test wrong forward
200 auto doc = result->document(3);
201 uint32_t wrong_type_fwd_value1;
202 doc->get_forward_value("fwd_column4", &wrong_type_fwd_value1);
203 ASSERT_EQ(wrong_type_fwd_value1, 0);
204
205 uint64_t wrong_type_fwd_value2;
206 doc->get_forward_value("fwd_column3", &wrong_type_fwd_value2);
207 ASSERT_EQ(wrong_type_fwd_value2, 0);
208
209 bool wrong_type_fwd_value3;
210 doc->get_forward_value("fwd_column2", &wrong_type_fwd_value3);
211 ASSERT_EQ(wrong_type_fwd_value3, false);
212
213 // test insert json format
214 WriteRequestPtr write_request2 = WriteRequest::Create();
215 write_request2->set_collection_name("test_collection");
216 write_request2->add_forward_columns(
217 {"fwd_column1", "fwd_column2", "fwd_column3", "fwd_column4"});
218 write_request2->add_index_column("test_column", DataType::VECTOR_FP32, 8);
219 WriteRequest::RowPtr row = write_request2->add_row();
220 row->set_primary_key(10);
221 row->set_operation_type(OperationType::INSERT);
222 row->add_index_value_by_json(
223 "[10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8]");
224 row->add_forward_value("hello" + std::to_string(10));
225 row->add_forward_value((int64_t)10);
226 row->add_forward_value((float)10);
227 row->add_forward_value((double)10);
228 status = client->write(*write_request2);
229 ASSERT_EQ(status.reason, "Success");
230 ASSERT_EQ(status.code, 0);
231
232 // test query json format
233 QueryRequestPtr query_request2 = QueryRequest::Create();
234 QueryResponsePtr query_response2 = QueryResponse::Create();
235 query_request2->set_collection_name("test_collection");
236 auto knn_param2 = query_request2->add_knn_query_param();
237 knn_param2->set_column_name("test_column");
238 knn_param2->set_topk(10);
239 knn_param2->set_features_by_json(
240 "[10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8]");
241 knn_param2->set_data_type(DataType::VECTOR_FP32);
242 knn_param2->set_dimension(8);
243 status = client->query(*query_request2, query_response2.get());
244 ASSERT_EQ(status.reason, "Success");
245 ASSERT_EQ(status.code, 0);
246
247 ASSERT_EQ(query_response2->result_count(), 1);
248 auto result2 = query_response2->result(0);
249 ASSERT_EQ(result2->document_count(), 10);
250 {
251 auto doc = result2->document(0);
252 ASSERT_EQ(doc->primary_key(), 10);
253 ASSERT_EQ(doc->score(), 0.0f);
254 std::string fwd_val1;
255 int64_t fwd_val2;
256 float fwd_val3;
257 double fwd_val4;
258 doc->get_forward_value("fwd_column1", &fwd_val1);
259 doc->get_forward_value("fwd_column2", &fwd_val2);
260 doc->get_forward_value("fwd_column3", &fwd_val3);
261 doc->get_forward_value("fwd_column4", &fwd_val4);
262 ASSERT_EQ(fwd_val1, "hello10");
263 ASSERT_EQ(fwd_val2, (int64_t)10);
264 ASSERT_EQ(fwd_val3, (float)10);
265 ASSERT_EQ(fwd_val4, (double)10);
266 }
267
268 // Drop collection
269 status = client->drop_collection("test_collection");
270 ASSERT_EQ(status.reason, "Success");
271 ASSERT_EQ(status.code, 0);
272
273 ret = engine.stop();
274 ASSERT_EQ(ret, 0);
275
276 ret = engine.cleanup();
277 ASSERT_EQ(ret, 0);
278}
279