1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | |
16 | * \author Haichao.chc |
17 | * \date Oct 2020 |
18 | * \brief Client tool that can send benchmark requests |
19 | * to remote proxima be server |
20 | */ |
21 | |
#include <unistd.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include <ailego/parallel/thread_pool.h>
#include <ailego/utility/string_helper.h>
#include <ailego/utility/time_helper.h>
#include <bvar/latency_recorder.h>
#include <gflags/gflags.h>
#include <google/protobuf/util/json_util.h>
#include "common/logger.h"
#include "common/protobuf_helper.h"
#include "common/version.h"
#include "proxima_search_client.h"
#include "vecs_reader.h"
36 | |
37 | using namespace proxima::be; |
38 | |
39 | DEFINE_string(command, "" , "Command type: search/insert/delete/update" ); |
40 | DEFINE_string(host, "" , "The host of proxima search engine" ); |
41 | DEFINE_string(collection, "" , "Collection name" ); |
42 | DEFINE_string(column, "" , "Column name" ); |
43 | DEFINE_string(file, "" , "Input file path" ); |
44 | DEFINE_string(protocol, "grpc" , "Protocol http or grpc" ); |
45 | DEFINE_string(data_type, "float" , "Input data type" ); |
46 | DEFINE_uint32(concurrency, 10, "Concurrency connection to server(default 10)" ); |
47 | DEFINE_uint32(topk, 10, "Query topk results(default 10)" ); |
48 | DEFINE_bool(perf, false, "Output perf result" ); |
49 | DEFINE_uint32(rows, 0, "Limit loaded records count" ); |
50 | |
// gflags validator for required string flags: the value passes only when it
// is non-empty. `flagname` belongs to the validator signature and is unused.
static bool ValidateNotEmpty(const char *flagname, const std::string &value) {
  (void)flagname;
  return value.size() > 0;
}
54 | |
55 | DEFINE_validator(command, ValidateNotEmpty); |
56 | DEFINE_validator(host, ValidateNotEmpty); |
57 | DEFINE_validator(collection, ValidateNotEmpty); |
58 | DEFINE_validator(file, ValidateNotEmpty); |
59 | |
// Prints the bench_client command line help to stdout. main() calls it for
// -h/-help/--help before gflags parses the arguments.
static inline void PrintUsage() {
  static const char *const kArgHelp[] = {
      " --command Command type: search|insert|delete|update|recall",
      " --host The host of proxima be",
      " --collection Specify collection name",
      " --column Specify column name",
      " --file Read input data from file",
      " --protocol Protocol http or grpc",
      " --data_type Support float/binary now (default float)",
      " --concurrency Send concurrency (default 10)",
      " --topk Topk results (default 10)",
      " --perf Output perf result (default false)",
      " --rows Limit loaded rows count",
      " --help, -h Display help info",
      " --version, -v Display version info",
  };

  std::cout << "Usage:" << std::endl
            << " bench_client <args>" << std::endl
            << std::endl
            << "Args: " << std::endl;
  for (const char *help_line : kArgHelp) {
    std::cout << help_line << std::endl;
  }
}
82 | |
// One input row loaded from --file and shared with the request workers.
// key and dimension default to 0: a text line may carry only a key, in which
// case no feature (and no dimension) is parsed for it.
struct Record {
  uint64_t key{0};
  // Raw feature bytes: float values as stored in memory, or bit-packed bytes
  // when --data_type=binary.
  std::string vector;
  // Forward attribute text (third ';'-separated field of a text input line).
  std::string attributes;
  uint32_t dimension{0};
};
90 | |
// Global variables for recall statistics. For each cutoff, *_total_count is
// the number of knn documents compared and *_hit_count the number of those
// also found among the linear-search results. DoRecallProxima() updates them
// from thread-pool workers, hence std::atomic.
static std::atomic<uint64_t> g_top1_total_count{0U};
static std::atomic<uint64_t> g_top1_hit_count{0U};
static std::atomic<uint64_t> g_top10_total_count{0U};
static std::atomic<uint64_t> g_top10_hit_count{0U};
static std::atomic<uint64_t> g_top50_total_count{0U};
static std::atomic<uint64_t> g_top50_hit_count{0U};
static std::atomic<uint64_t> g_top100_total_count{0U};
static std::atomic<uint64_t> g_top100_hit_count{0U};
static std::atomic<uint64_t> g_topk_total_count{0U};
static std::atomic<uint64_t> g_topk_hit_count{0U};
102 | |
103 | // Global varibles for qps statistics |
104 | static std::atomic<bool> g_running(false); |
105 | static uint64_t g_min_insert_qps(-1U); |
106 | static uint64_t g_max_insert_qps(0U); |
107 | static uint64_t g_min_search_qps(-1U); |
108 | static uint64_t g_max_search_qps(0U); |
109 | static uint64_t g_min_update_qps(-1U); |
110 | static uint64_t g_max_update_qps(0U); |
111 | static uint64_t g_min_delete_qps(-1U); |
112 | static uint64_t g_max_delete_qps(0U); |
113 | static std::vector<Record> g_record_list; |
114 | static ProximaSearchClientPtr g_client; |
115 | static bvar::LatencyRecorder g_insert_latency_recorder; |
116 | static bvar::LatencyRecorder g_search_latency_recorder; |
117 | static bvar::LatencyRecorder g_update_latency_recorder; |
118 | static bvar::LatencyRecorder g_delete_latency_recorder; |
119 | |
120 | static bool LoadFromTextFile() { |
121 | std::ifstream file_stream(FLAGS_file); |
122 | if (!file_stream.is_open()) { |
123 | LOG_ERROR("Can't open input file %s" , FLAGS_file.c_str()); |
124 | return false; |
125 | } |
126 | |
127 | std::string line; |
128 | while (std::getline(file_stream, line)) { |
129 | line.erase(line.find_last_not_of('\n') + 1); |
130 | if (line.empty()) { |
131 | continue; |
132 | } |
133 | std::vector<std::string> res; |
134 | ailego::StringHelper::Split(line, ';', &res); |
135 | if (res.size() < 1) { |
136 | LOG_ERROR("Bad input line" ); |
137 | continue; |
138 | } |
139 | |
140 | Record record; |
141 | // Parse key |
142 | uint64_t key = std::stoull(res[0]); |
143 | record.key = key; |
144 | // Parse feature |
145 | if (res.size() >= 2) { |
146 | if (FLAGS_data_type == "binary" ) { |
147 | std::vector<uint8_t> vec; |
148 | ailego::StringHelper::Split(res[1], ' ', &vec); |
149 | if (vec.size() == 0 || vec.size() % 32 != 0) { |
150 | LOG_ERROR("Bad feature field" ); |
151 | continue; |
152 | } |
153 | |
154 | std::vector<uint8_t> tmp; |
155 | for (size_t i = 0; i < vec.size(); i += 8) { |
156 | uint8_t v = 0; |
157 | v |= (vec[i] & 0x01) << 7; |
158 | v |= (vec[i + 1] & 0x01) << 6; |
159 | v |= (vec[i + 2] & 0x01) << 5; |
160 | v |= (vec[i + 3] & 0x01) << 4; |
161 | v |= (vec[i + 4] & 0x01) << 3; |
162 | v |= (vec[i + 5] & 0x01) << 2; |
163 | v |= (vec[i + 6] & 0x01) << 1; |
164 | v |= (vec[i + 7] & 0x01) << 0; |
165 | tmp.push_back(v); |
166 | } |
167 | |
168 | record.vector = |
169 | std::string((const char *)tmp.data(), tmp.size() * sizeof(uint8_t)); |
170 | record.dimension = vec.size(); |
171 | } else { |
172 | std::vector<float> feature; |
173 | ailego::StringHelper::Split(res[1], ' ', &feature); |
174 | if (feature.size() == 0) { |
175 | LOG_ERROR("Bad feature field" ); |
176 | continue; |
177 | } |
178 | record.vector = std::string((const char *)feature.data(), |
179 | feature.size() * sizeof(float)); |
180 | record.dimension = feature.size(); |
181 | } |
182 | } |
183 | // Parse attributes |
184 | if (res.size() >= 3) { |
185 | record.attributes = res[2]; |
186 | } |
187 | g_record_list.emplace_back(record); |
188 | |
189 | // Limit rows count |
190 | if (FLAGS_rows > 0 && g_record_list.size() >= FLAGS_rows) { |
191 | break; |
192 | } |
193 | } |
194 | file_stream.close(); |
195 | |
196 | return true; |
197 | } |
198 | |
199 | static bool LoadFromVecsFile() { |
200 | tools::VecsReader reader; |
201 | if (!reader.load(FLAGS_file)) { |
202 | LOG_ERROR("Load vecs file failed." ); |
203 | return false; |
204 | } |
205 | |
206 | for (uint32_t i = 0; i < reader.num_vecs(); i++) { |
207 | Record new_record; |
208 | uint64_t key = reader.get_key(i); |
209 | const char *feature = (const char *)reader.get_vector(i); |
210 | |
211 | new_record.key = key; |
212 | new_record.vector.append(feature, reader.index_meta().element_size()); |
213 | new_record.dimension = reader.index_meta().dimension(); |
214 | g_record_list.emplace_back(new_record); |
215 | |
216 | // Limit rows count |
217 | if (FLAGS_rows > 0 && g_record_list.size() >= FLAGS_rows) { |
218 | break; |
219 | } |
220 | } |
221 | |
222 | return true; |
223 | } |
224 | |
225 | static bool LoadRecords() { |
226 | bool ret; |
227 | if (FLAGS_file.find(".vecs" ) != std::string::npos) { |
228 | ret = LoadFromVecsFile(); |
229 | } else { |
230 | ret = LoadFromTextFile(); |
231 | } |
232 | |
233 | return ret; |
234 | } |
235 | |
236 | static bool InitClient() { |
237 | if (FLAGS_protocol == "http" ) { |
238 | g_client = ProximaSearchClient::Create("HttpClient" ); |
239 | } else if (FLAGS_protocol == "grpc" ) { |
240 | g_client = ProximaSearchClient::Create("GrpcClient" ); |
241 | } else { |
242 | LOG_ERROR("Unknown protocol, only support http or grpc now. protocol[%s]" , |
243 | FLAGS_protocol.c_str()); |
244 | return false; |
245 | } |
246 | |
247 | proxima::be::ChannelOptions options(FLAGS_host); |
248 | options.connection_count = FLAGS_concurrency; |
249 | options.timeout_ms = 60000; |
250 | Status status = g_client->connect(options); |
251 | if (status.code != 0) { |
252 | LOG_ERROR("Connect failed. code[%d] reason[%s]" , status.code, |
253 | status.reason.c_str()); |
254 | return false; |
255 | } |
256 | return true; |
257 | } |
258 | |
259 | static void DoSearchProxima(Record *record) { |
260 | ailego::ElapsedTime timer; |
261 | QueryRequestPtr request = QueryRequest::Create(); |
262 | request->set_collection_name(FLAGS_collection); |
263 | auto knn_param = request->add_knn_query_param(); |
264 | knn_param->set_column_name(FLAGS_column); |
265 | knn_param->set_topk(FLAGS_topk); |
266 | knn_param->set_features(record->vector.c_str(), record->vector.size(), 1); |
267 | if (FLAGS_data_type == "binary" ) { |
268 | knn_param->set_data_type(DataType::VECTOR_BINARY32); |
269 | } else { |
270 | knn_param->set_data_type(DataType::VECTOR_FP32); |
271 | } |
272 | knn_param->set_dimension(record->dimension); |
273 | |
274 | QueryResponsePtr response = QueryResponse::Create(); |
275 | Status status = g_client->query(*request, response.get()); |
276 | if (status.code != 0) { |
277 | LOG_ERROR("Search records failed. query_id[%zu] code[%d] reason[%s] " , |
278 | (size_t)record->key, status.code, status.reason.c_str()); |
279 | return; |
280 | } |
281 | |
282 | auto result = response->result(0); |
283 | std::string result_str; |
284 | for (size_t i = 0; i < result->document_count(); i++) { |
285 | auto doc = result->document(i); |
286 | std::string attr; |
287 | doc->get_forward_value("forward" , &attr); |
288 | if (attr.empty()) { |
289 | ailego::StringHelper::Append(&result_str, " " , doc->primary_key(), ":" , |
290 | doc->score()); |
291 | } else { |
292 | ailego::StringHelper::Append(&result_str, " " , doc->primary_key(), ":" , |
293 | doc->score(), ":" , attr); |
294 | } |
295 | } |
296 | |
297 | uint64_t latency_us = timer.micro_seconds(); |
298 | g_search_latency_recorder << latency_us; |
299 | |
300 | if (!FLAGS_perf) { |
301 | LOG_INFO( |
302 | "Search records success. query_id[%zu] res_num[%zu] results[%s] " |
303 | "rt[%zuus]" , |
304 | (size_t)record->key, result->document_count(), result_str.c_str(), |
305 | (size_t)latency_us); |
306 | } |
307 | } |
308 | |
309 | static void DoInsertProxima(Record *record) { |
310 | ailego::ElapsedTime timer; |
311 | WriteRequestPtr request = WriteRequest::Create(); |
312 | request->set_collection_name(FLAGS_collection); |
313 | if (FLAGS_data_type == "binary" ) { |
314 | request->add_index_column(FLAGS_column, DataType::VECTOR_BINARY32, |
315 | record->dimension); |
316 | } else { |
317 | request->add_index_column(FLAGS_column, DataType::VECTOR_FP32, |
318 | record->dimension); |
319 | } |
320 | // Support forward column temporarily |
321 | if (!record->attributes.empty()) { |
322 | request->add_forward_column("forward" ); |
323 | } |
324 | |
325 | auto row = request->add_row(); |
326 | row->set_operation_type(OperationType::INSERT); |
327 | row->set_primary_key(record->key); |
328 | row->add_index_value(record->vector.c_str(), record->vector.size()); |
329 | // Support forward column temporarily |
330 | if (!record->attributes.empty()) { |
331 | row->add_forward_value(record->attributes); |
332 | } |
333 | |
334 | Status status = g_client->write(*request); |
335 | if (status.code != 0) { |
336 | LOG_ERROR("Insert record failed. key[%zu] code[%d] reason[%s]" , |
337 | (size_t)record->key, status.code, status.reason.c_str()); |
338 | return; |
339 | } |
340 | |
341 | uint64_t latency_us = timer.micro_seconds(); |
342 | g_insert_latency_recorder << latency_us; |
343 | |
344 | if (!FLAGS_perf) { |
345 | LOG_INFO("Insert record success. key[%zu] rt[%zuus]" , (size_t)record->key, |
346 | (size_t)latency_us); |
347 | } |
348 | } |
349 | |
350 | static void DoDeleteProxima(Record *record) { |
351 | ailego::ElapsedTime timer; |
352 | WriteRequestPtr request = WriteRequest::Create(); |
353 | request->set_collection_name(FLAGS_collection); |
354 | auto row = request->add_row(); |
355 | row->set_operation_type(OperationType::DELETE); |
356 | row->set_primary_key(record->key); |
357 | row->add_index_value(record->vector.c_str(), record->vector.size()); |
358 | |
359 | Status status = g_client->write(*request); |
360 | if (status.code != 0) { |
361 | LOG_ERROR("Delete record failed. key[%zu] code[%d] reason[%s]" , |
362 | (size_t)record->key, status.code, status.reason.c_str()); |
363 | return; |
364 | } |
365 | |
366 | uint64_t latency_us = timer.micro_seconds(); |
367 | g_delete_latency_recorder << latency_us; |
368 | |
369 | if (!FLAGS_perf) { |
370 | LOG_INFO("Delete record success. key[%zu] rt[%zuus]" , (size_t)record->key, |
371 | (size_t)latency_us); |
372 | } |
373 | } |
374 | |
375 | static void DoUpdateProxima(Record *record) { |
376 | ailego::ElapsedTime timer; |
377 | WriteRequestPtr request = WriteRequest::Create(); |
378 | request->set_collection_name(FLAGS_collection); |
379 | if (FLAGS_data_type == "binary" ) { |
380 | request->add_index_column(FLAGS_column, DataType::VECTOR_BINARY32, |
381 | record->dimension); |
382 | } else { |
383 | request->add_index_column(FLAGS_column, DataType::VECTOR_FP32, |
384 | record->dimension); |
385 | } |
386 | auto row = request->add_row(); |
387 | row->set_operation_type(OperationType::UPDATE); |
388 | row->set_primary_key(record->key); |
389 | row->add_index_value(record->vector.c_str(), record->vector.size()); |
390 | |
391 | Status status = g_client->write(*request); |
392 | if (status.code != 0) { |
393 | LOG_ERROR("Update record failed. key[%zu] code[%d] reason[%s]" , |
394 | (size_t)record->key, status.code, status.reason.c_str()); |
395 | return; |
396 | } |
397 | |
398 | uint64_t latency_us = timer.micro_seconds(); |
399 | g_update_latency_recorder << latency_us; |
400 | |
401 | if (!FLAGS_perf) { |
402 | LOG_INFO("Update record success. key[%zu] rt[%zuus]" , (size_t)record->key, |
403 | (size_t)latency_us); |
404 | } |
405 | } |
406 | |
// Recall accounting for DoRecallProxima(). Expects `result1` (knn results) and
// `result2` (linear-search results) to be in scope where it is expanded.
// For each of the first `topk` documents of result1 it increments
// `total_count`, and increments `hit_count` if a document among the first
// `topk` of result2 has the same primary key or an equal score (presumably so
// that equal-distance ties count as hits). Wrapped in do/while(0) so it acts
// as a single statement; its locals use trailing underscores to avoid
// shadowing names at the expansion site.
#define ADD_RECALL_COUNT(topk, total_count, hit_count)                       \
  do {                                                                       \
    const size_t recall_limit_ = static_cast<size_t>(topk);                  \
    for (size_t knn_idx_ = 0;                                                \
         knn_idx_ < recall_limit_ && knn_idx_ < result1->document_count();   \
         ++knn_idx_) {                                                       \
      ++(total_count);                                                       \
      auto knn_doc_ = result1->document(knn_idx_);                           \
      for (size_t lin_idx_ = 0;                                              \
           lin_idx_ < recall_limit_ && lin_idx_ < result2->document_count(); \
           ++lin_idx_) {                                                     \
        auto lin_doc_ = result2->document(lin_idx_);                         \
        if (knn_doc_->primary_key() == lin_doc_->primary_key() ||            \
            knn_doc_->score() == lin_doc_->score()) {                        \
          ++(hit_count);                                                     \
          break;                                                             \
        }                                                                    \
      }                                                                      \
    }                                                                        \
  } while (0)
420 | |
421 | |
422 | static void DoRecallProxima(Record *record) { |
423 | QueryRequestPtr request = QueryRequest::Create(); |
424 | request->set_collection_name(FLAGS_collection); |
425 | auto knn_param = request->add_knn_query_param(); |
426 | knn_param->set_column_name(FLAGS_column); |
427 | knn_param->set_topk(FLAGS_topk); |
428 | knn_param->set_features(record->vector.c_str(), record->vector.size(), 1); |
429 | if (FLAGS_data_type == "binary" ) { |
430 | knn_param->set_data_type(DataType::VECTOR_BINARY32); |
431 | } else { |
432 | knn_param->set_data_type(DataType::VECTOR_FP32); |
433 | } |
434 | knn_param->set_dimension(record->dimension); |
435 | |
436 | // 1. get knn results |
437 | auto response1 = QueryResponse::Create(); |
438 | Status status = g_client->query(*request, response1.get()); |
439 | if (status.code != 0) { |
440 | LOG_ERROR("Knn search records failed. query_id[%zu] code[%d] reason[%s] " , |
441 | (size_t)record->key, status.code, status.reason.c_str()); |
442 | return; |
443 | } |
444 | auto result1 = response1->result(0); |
445 | |
446 | // 2. get linear knn results |
447 | knn_param->set_linear(true); |
448 | auto response2 = QueryResponse::Create(); |
449 | status = g_client->query(*request, response2.get()); |
450 | if (status.code != 0) { |
451 | LOG_ERROR( |
452 | "Linear search records failed. query_id[%zu] code[%d] reason[%s] " , |
453 | (size_t)record->key, status.code, status.reason.c_str()); |
454 | return; |
455 | } |
456 | auto result2 = response2->result(0); |
457 | |
458 | if (result1->document_count() != result2->document_count()) { |
459 | LOG_ERROR( |
460 | "Knn search results count mismatch linear search results. result1[%lu] " |
461 | "result2[%lu]" , |
462 | result1->document_count(), result2->document_count()); |
463 | return; |
464 | } |
465 | |
466 | if (FLAGS_topk > 1) { |
467 | ADD_RECALL_COUNT(1, g_top1_total_count, g_top1_hit_count); |
468 | } |
469 | |
470 | if (FLAGS_topk > 10) { |
471 | ADD_RECALL_COUNT(10, g_top10_total_count, g_top10_hit_count); |
472 | } |
473 | |
474 | if (FLAGS_topk > 50) { |
475 | ADD_RECALL_COUNT(50, g_top50_total_count, g_top50_hit_count); |
476 | } |
477 | |
478 | if (FLAGS_topk > 100) { |
479 | ADD_RECALL_COUNT(100, g_top100_total_count, g_top100_hit_count); |
480 | } |
481 | |
482 | ADD_RECALL_COUNT(FLAGS_topk, g_topk_total_count, g_topk_hit_count); |
483 | } |
484 | |
// Prints the --perf summary for one latency recorder to stdout. The summary
// covers processed count, average qps, and the per-second qps extremes sampled
// by Monitor(). It also covers average and max latency and latency
// percentiles, all in microseconds.
// `max_qps` / `min_qps` may be plain or std::atomic integers. `min_qps` starts
// at a sentinel larger than any real sample. If Monitor() never took a sample
// (min > max, e.g. the run ended during its warm-up), 0 is printed instead of
// the sentinel. Wrapped in do/while(0) so the macro acts as one statement.
#define OUTPUT_PERF_RESULT(recorder, max_qps, min_qps)                       \
  do {                                                                       \
    const uint64_t perf_max_qps = static_cast<uint64_t>(max_qps);            \
    const uint64_t perf_min_qps = static_cast<uint64_t>(min_qps);            \
    std::cout << "====================PERFORMANCE======================"    \
              << std::endl;                                                  \
    std::cout << "Process count : " << (recorder).count() << std::endl;      \
    std::cout << "Average qps : " << (recorder).qps() << "/s" << std::endl;  \
    std::cout << "Maximum qps : " << perf_max_qps << "/s" << std::endl;      \
    std::cout << "Minimum qps : "                                            \
              << (perf_min_qps > perf_max_qps ? uint64_t{0} : perf_min_qps)  \
              << "/s" << std::endl;                                          \
    std::cout << "Average latency: " << (recorder).latency() << "us"         \
              << std::endl;                                                  \
    std::cout << "Maximum latency: " << (recorder).max_latency() << "us"     \
              << std::endl;                                                  \
    std::cout << "Percentile @1 : " << (recorder).latency_percentile(0.01)   \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @10 : " << (recorder).latency_percentile(0.10)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @25 : " << (recorder).latency_percentile(0.25)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @50 : " << (recorder).latency_percentile(0.50)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @75 : " << (recorder).latency_percentile(0.75)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @90 : " << (recorder).latency_percentile(0.90)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @95 : " << (recorder).latency_percentile(0.95)  \
              << "us" << std::endl;                                          \
    std::cout << "Percentile @99 : " << (recorder).latency_percentile(0.99)  \
              << "us" << std::endl;                                          \
  } while (0)
511 | |
512 | static void SearchRecords() { |
513 | if (FLAGS_column.empty()) { |
514 | LOG_ERROR("Input argument column can't be emtpy" ); |
515 | return; |
516 | } |
517 | |
518 | ailego::ThreadPool thread_pool(FLAGS_concurrency, false); |
519 | for (size_t i = 0; i < g_record_list.size(); i++) { |
520 | thread_pool.execute(DoSearchProxima, &g_record_list[i]); |
521 | |
522 | while (thread_pool.pending_count() > 1000) { |
523 | usleep(1000); |
524 | thread_pool.wake_all(); |
525 | } |
526 | } |
527 | thread_pool.wait_finish(); |
528 | |
529 | if (FLAGS_perf) { |
530 | OUTPUT_PERF_RESULT(g_search_latency_recorder, g_max_search_qps, |
531 | g_min_search_qps); |
532 | } |
533 | } |
534 | |
535 | static void InsertRecords() { |
536 | if (FLAGS_column.empty()) { |
537 | LOG_ERROR("Input argument column can't be emtpy" ); |
538 | return; |
539 | } |
540 | |
541 | ailego::ThreadPool thread_pool(FLAGS_concurrency, false); |
542 | for (size_t i = 0; i < g_record_list.size(); i++) { |
543 | thread_pool.execute(DoInsertProxima, &g_record_list[i]); |
544 | while (thread_pool.pending_count() > 1000) { |
545 | usleep(1000); |
546 | thread_pool.wake_all(); |
547 | } |
548 | } |
549 | thread_pool.wait_finish(); |
550 | |
551 | if (FLAGS_perf) { |
552 | OUTPUT_PERF_RESULT(g_insert_latency_recorder, g_max_insert_qps, |
553 | g_min_insert_qps); |
554 | } |
555 | } |
556 | |
557 | static void DeleteRecords() { |
558 | ailego::ThreadPool thread_pool(FLAGS_concurrency, false); |
559 | for (size_t i = 0; i < g_record_list.size(); i++) { |
560 | thread_pool.execute(DoDeleteProxima, &g_record_list[i]); |
561 | while (thread_pool.pending_count() > 1000) { |
562 | usleep(1000); |
563 | thread_pool.wake_all(); |
564 | } |
565 | } |
566 | thread_pool.wait_finish(); |
567 | |
568 | if (FLAGS_perf) { |
569 | OUTPUT_PERF_RESULT(g_delete_latency_recorder, g_max_delete_qps, |
570 | g_min_delete_qps); |
571 | } |
572 | } |
573 | |
574 | static void UpdateRecords() { |
575 | if (FLAGS_column.empty()) { |
576 | LOG_ERROR("Input argument column can't be emtpy" ); |
577 | return; |
578 | } |
579 | |
580 | ailego::ThreadPool thread_pool(FLAGS_concurrency, false); |
581 | for (size_t i = 0; i < g_record_list.size(); i++) { |
582 | thread_pool.execute(DoUpdateProxima, &g_record_list[i]); |
583 | while (thread_pool.pending_count() > 1000) { |
584 | usleep(1000); |
585 | thread_pool.wake_all(); |
586 | } |
587 | } |
588 | thread_pool.wait_finish(); |
589 | |
590 | if (FLAGS_perf) { |
591 | OUTPUT_PERF_RESULT(g_update_latency_recorder, g_max_update_qps, |
592 | g_min_update_qps); |
593 | } |
594 | } |
595 | |
596 | static void RecallRecords() { |
597 | if (FLAGS_column.empty()) { |
598 | LOG_ERROR("Input argument column can't be emtpy" ); |
599 | return; |
600 | } |
601 | |
602 | ailego::ThreadPool thread_pool(FLAGS_concurrency, false); |
603 | for (size_t i = 0; i < g_record_list.size(); i++) { |
604 | thread_pool.execute(DoRecallProxima, &g_record_list[i]); |
605 | while (thread_pool.pending_count() > 1000) { |
606 | usleep(1000); |
607 | thread_pool.wake_all(); |
608 | } |
609 | } |
610 | thread_pool.wait_finish(); |
611 | |
612 | // Output recall ratio |
613 | if (FLAGS_topk > 1) { |
614 | double top1_hit_ratio = g_top1_total_count > 0 |
615 | ? (double)g_top1_hit_count / g_top1_total_count |
616 | : 0.0f; |
617 | std::cout << "Recall @1: " << top1_hit_ratio << std::endl; |
618 | } |
619 | |
620 | if (FLAGS_topk > 10) { |
621 | double top10_hit_ratio = |
622 | g_top10_total_count > 0 |
623 | ? (double)g_top10_hit_count / g_top10_total_count |
624 | : 0.0f; |
625 | std::cout << "Recall @10: " << top10_hit_ratio << std::endl; |
626 | } |
627 | |
628 | if (FLAGS_topk > 50) { |
629 | double top50_hit_ratio = |
630 | g_top50_total_count > 0 |
631 | ? (double)g_top50_hit_count / g_top50_total_count |
632 | : 0.0f; |
633 | std::cout << "Recall @50: " << top50_hit_ratio << std::endl; |
634 | } |
635 | |
636 | if (FLAGS_topk > 100) { |
637 | double top100_hit_ratio = |
638 | g_top100_total_count > 0 |
639 | ? (double)g_top100_hit_count / g_top100_total_count |
640 | : 0.0f; |
641 | std::cout << "Recall @100: " << top100_hit_ratio << std::endl; |
642 | } |
643 | |
644 | double topk_hit_ratio = g_topk_total_count > 0 |
645 | ? (double)g_topk_hit_count / g_topk_total_count |
646 | : 0.0f; |
647 | std::cout << "Recall @" << FLAGS_topk << ": " << topk_hit_ratio << std::endl; |
648 | } |
649 | |
650 | static void Monitor() { |
651 | std::this_thread::sleep_for(std::chrono::seconds(5)); |
652 | while (g_running) { |
653 | std::this_thread::sleep_for(std::chrono::seconds(1)); |
654 | if (FLAGS_command == "search" ) { |
655 | uint64_t qps = (uint64_t)g_search_latency_recorder.qps(1); |
656 | if (qps > g_max_search_qps) { |
657 | g_max_search_qps = qps; |
658 | } |
659 | if (qps < g_min_search_qps) { |
660 | g_min_search_qps = qps; |
661 | } |
662 | } else if (FLAGS_command == "insert" ) { |
663 | uint64_t qps = (uint64_t)g_insert_latency_recorder.qps(1); |
664 | if (qps > g_max_insert_qps) { |
665 | g_max_insert_qps = qps; |
666 | } |
667 | if (qps < g_min_insert_qps) { |
668 | g_min_insert_qps = qps; |
669 | } |
670 | } else if (FLAGS_command == "update" ) { |
671 | uint64_t qps = (uint64_t)g_update_latency_recorder.qps(1); |
672 | if (qps > g_max_update_qps) { |
673 | g_max_update_qps = qps; |
674 | } |
675 | if (qps < g_min_update_qps) { |
676 | g_min_update_qps = qps; |
677 | } |
678 | } else if (FLAGS_command == "delete" ) { |
679 | uint64_t qps = (uint64_t)g_delete_latency_recorder.qps(1); |
680 | if (qps > g_max_delete_qps) { |
681 | g_max_delete_qps = qps; |
682 | } |
683 | if (qps < g_min_delete_qps) { |
684 | g_min_delete_qps = qps; |
685 | } |
686 | } |
687 | } |
688 | } |
689 | |
690 | int main(int argc, char **argv) { |
691 | // Parse arguments |
692 | for (int i = 1; i < argc; ++i) { |
693 | const char *arg = argv[i]; |
694 | if (!strcmp(arg, "-help" ) || !strcmp(arg, "--help" ) || !strcmp(arg, "-h" )) { |
695 | PrintUsage(); |
696 | exit(0); |
697 | } else if (!strcmp(arg, "-version" ) || !strcmp(arg, "--version" ) || |
698 | !strcmp(arg, "-v" )) { |
699 | std::cout << proxima::be::Version::Details() << std::endl; |
700 | exit(0); |
701 | } |
702 | } |
703 | gflags::ParseCommandLineNonHelpFlags(&argc, &argv, false); |
704 | |
705 | // Init client channel |
706 | if (!InitClient()) { |
707 | LOG_ERROR("Init client failed. host[%s]" , FLAGS_host.c_str()); |
708 | exit(1); |
709 | } |
710 | |
711 | // Load data from input file |
712 | if (!LoadRecords()) { |
713 | LOG_ERROR("Load data from file failed. file[%s]" , FLAGS_file.c_str()); |
714 | exit(1); |
715 | } |
716 | std::cout << "Load data complete. num[" << g_record_list.size() << "]" |
717 | << std::endl; |
718 | g_running = true; |
719 | |
720 | // Add monitor thread |
721 | std::thread *monitor_thread = nullptr; |
722 | if (FLAGS_perf) { |
723 | monitor_thread = new std::thread(Monitor); |
724 | } |
725 | |
726 | // Register commands |
727 | std::map<std::string, std::function<void(void)>> record_ops = { |
728 | {"search" , SearchRecords}, |
729 | {"insert" , InsertRecords}, |
730 | {"update" , UpdateRecords}, |
731 | {"delete" , DeleteRecords}, |
732 | {"recall" , RecallRecords}}; |
733 | if (record_ops.find(FLAGS_command) != record_ops.end()) { |
734 | record_ops[FLAGS_command](); |
735 | } else { |
736 | LOG_ERROR("Unsupported command type: %s" , FLAGS_command.c_str()); |
737 | exit(1); |
738 | } |
739 | |
740 | g_running = false; |
741 | if (monitor_thread) { |
742 | monitor_thread->join(); |
743 | delete monitor_thread; |
744 | } |
745 | |
746 | return 0; |
747 | } |
748 | |
749 | |
// Both helper macros are private to this file; undefine them together.
#undef ADD_RECALL_COUNT
#undef OUTPUT_PERF_RESULT
751 | |