1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15
16 * \author Haichao.chc
17 * \date Oct 2020
18 * \brief Implementation with grpc protobuf protocol
19 */
20
21#pragma once
22
23#include <map>
24#include <queue>
25
26#ifdef __GNUC__
27#pragma GCC diagnostic push
28#pragma GCC diagnostic ignored "-Wshadow"
29#pragma GCC diagnostic ignored "-Wunused-parameter"
30#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
31#endif
32#include <brpc/channel.h>
33#include <brpc/selective_channel.h>
34#ifdef __GNUC__
35#pragma GCC diagnostic pop
36#endif
37
38#include "proxima_search_client.h"
39
40#ifdef __GNUC__
41#pragma GCC diagnostic push
42#pragma GCC diagnostic ignored "-Wshadow"
43#pragma GCC diagnostic ignored "-Wunused-parameter"
44#endif
45#include "proto/proxima_be.pb.h"
46#ifdef __GNUC__
47#pragma GCC diagnostic pop
48#endif
49
50namespace proxima {
51namespace be {
52
// Forward declarations of the protobuf-backed request/response types.
// Some of them (PbWriteRequest, PbQueryRequest, PbGetDocumentRequest) are
// taken by const reference in GrpcProximaSearchClient::validate() before
// their definitions appear later in this header.
class PbWriteRequest;
class PbQueryRequest;
class PbQueryResponse;
class PbGetDocumentRequest;
class PbGetDocumentResponse;
58
/*
 * ProximaSearchClient implementation with the grpc protobuf protocol.
 *
 * All methods are declared here and defined out of line. Each protected
 * rpc_* virtual wraps one remote call and takes a brpc::Controller plus a
 * protobuf request/response pair.
 */
class GrpcProximaSearchClient : public ProximaSearchClient {
 public:
  //! Constructor
  GrpcProximaSearchClient() = default;

  //! Destructor
  ~GrpcProximaSearchClient() override = default;

  //! Connect to the remote server using the given channel options
  Status connect(const ChannelOptions &options) override;

  //! Close the connection to the remote server
  Status close() override;

  //! Create a collection with the given config
  Status create_collection(const CollectionConfig &config) override;

  //! Drop the collection with the given name
  Status drop_collection(const std::string &collection_name) override;

  //! Get information about the collection with the given name
  Status describe_collection(const std::string &collection_name,
                             CollectionInfo *collection_info) override;

  //! Get statistics of the collection with the given name
  Status stats_collection(const std::string &collection_name,
                          CollectionStats *stats) override;

  //! Get information about all collections
  Status list_collections(std::vector<CollectionInfo> *collections) override;

  //! Write records, including insert/update/delete operations
  Status write(const WriteRequest &request) override;

  //! Query records, including knn query search
  Status query(const QueryRequest &request, QueryResponse *response) override;

  //! Get a specific record by primary key
  Status get_document_by_key(const GetDocumentRequest &request,
                             GetDocumentResponse *response) override;

 protected:
  // Per-RPC hooks. They are virtual and protected, presumably so a subclass
  // (e.g. a test double) can replace the transport — confirm with callers.
  virtual void rpc_create_collection(brpc::Controller *cntl,
                                     const proto::CollectionConfig *request,
                                     proto::Status *response);

  virtual void rpc_drop_collection(brpc::Controller *cntl,
                                   const proto::CollectionName *request,
                                   proto::Status *response);

  virtual void rpc_describe_collection(
      brpc::Controller *cntl, const proto::CollectionName *request,
      proto::DescribeCollectionResponse *response);

  virtual void rpc_stats_collection(brpc::Controller *cntl,
                                    const proto::CollectionName *request,
                                    proto::StatsCollectionResponse *response);

  virtual void rpc_list_collections(brpc::Controller *cntl,
                                    const proto::ListCondition *request,
                                    proto::ListCollectionsResponse *response);

  virtual void rpc_write(brpc::Controller *cntl,
                         const proto::WriteRequest *request,
                         proto::Status *response);

  virtual void rpc_query(brpc::Controller *cntl,
                         const proto::QueryRequest *request,
                         proto::QueryResponse *response);

  virtual void rpc_get_document_by_key(brpc::Controller *cntl,
                                       const proto::GetDocumentRequest *request,
                                       proto::GetDocumentResponse *response);

 protected:
  // Error codes named after client-side conditions (channel init, RPC
  // failure, version mismatch, not connected, validation). Presumably these
  // are reported in Status by this client rather than by the server; confirm
  // against the implementation.
  static constexpr uint32_t ErrorCode_InitChannel = 10000;
  static constexpr uint32_t ErrorCode_RpcError = 10001;
  static constexpr uint32_t ErrorCode_MismatchedVersion = 10002;
  static constexpr uint32_t ErrorCode_NotConnected = 10003;
  static constexpr uint32_t ErrorCode_ValidateError = 10004;

 private:
  //! Check the server version; presumably reports failure details via *status
  bool check_server_version(Status *status);

  //! Convert struct to protobuf object
  void convert(const CollectionConfig &config,
               proto::CollectionConfig *pb_request);

  //! Convert protobuf object to struct
  void convert(const proto::CollectionInfo &pb_response,
               CollectionInfo *collection_info);

  //! Convert protobuf object to struct
  void convert(const proto::CollectionStats &pb_response,
               CollectionStats *collection_stats);

  //! Validate legality of collection config
  Status validate(const CollectionConfig &config);

  //! Validate legality of write request
  Status validate(const PbWriteRequest &request);

  //! Validate legality of query request
  Status validate(const PbQueryRequest &request);

  //! Validate legality of get document request
  Status validate(const PbGetDocumentRequest &request);

 protected:
  //! Connection flag; a plain bool, no synchronization is visible here
  bool connected_{false};

 private:
  //! Channel presumably used to issue all rpc_* calls
  brpc::SelectiveChannel client_channel_{};
};
177
178
179/*
180 * WriteRequest implementation with protobuf protocol
181 */
182class PbWriteRequest : public WriteRequest {
183 public:
184 /*
185 * A row describes the format of one record
186 */
187 class PbRow : public WriteRequest::Row {
188 public:
189 //! Constructor
190 PbRow(proto::WriteRequest::Row *p_row) : row_(p_row) {}
191
192 //! Destructor
193 ~PbRow() override = default;
194
195 //! Set primary key, must set
196 void set_primary_key(uint64_t val) override {
197 row_->set_primary_key(val);
198 }
199
200 //! Set operation type, default OP_INSERT
201 void set_operation_type(OperationType op_type) override {
202 row_->set_operation_type((proto::OperationType)op_type);
203 }
204
205 //! Set lsn, optional set, generally used by database repo
206 void set_lsn(uint64_t lsn) override {
207 row_->mutable_lsn_context()->set_lsn(lsn);
208 }
209
210 //! Set lsn context, optional set, generally used by database repo
211 void set_lsn_context(const std::string &lsn_context) override {
212 row_->mutable_lsn_context()->set_context(lsn_context);
213 }
214
215 //! Add forward value, must match forward column names
216 void add_forward_value(const std::string &val) override {
217 row_->mutable_forward_column_values()->add_values()->set_string_value(
218 val);
219 }
220
221 //! Add forward value with bool type
222 void add_forward_value(bool val) override {
223 row_->mutable_forward_column_values()->add_values()->set_bool_value(val);
224 }
225
226 //! Add forward value with int32 type
227 void add_forward_value(int32_t val) override {
228 row_->mutable_forward_column_values()->add_values()->set_int32_value(val);
229 }
230
231 //! Add forward value with int64 type
232 void add_forward_value(int64_t val) override {
233 row_->mutable_forward_column_values()->add_values()->set_int64_value(val);
234 }
235
236 //! Add forward value with uint32 type
237 void add_forward_value(uint32_t val) override {
238 row_->mutable_forward_column_values()->add_values()->set_uint32_value(
239 val);
240 }
241
242 //! Add forward value with uint64 type
243 void add_forward_value(uint64_t val) override {
244 row_->mutable_forward_column_values()->add_values()->set_uint64_value(
245 val);
246 }
247
248 //! Add forward value with float type
249 void add_forward_value(float val) override {
250 row_->mutable_forward_column_values()->add_values()->set_float_value(val);
251 }
252
253 //! Add forward value with double type
254 void add_forward_value(double val) override {
255 row_->mutable_forward_column_values()->add_values()->set_double_value(
256 val);
257 }
258
259 //! Add index value, vector bytes type
260 void add_index_value(const void *val, size_t val_len) override {
261 row_->mutable_index_column_values()->add_values()->set_bytes_value(
262 std::string((const char *)val, val_len));
263 }
264
265 //! Add index value, vector array type
266 void add_index_value(const std::vector<float> &val) override {
267 row_->mutable_index_column_values()->add_values()->set_bytes_value(
268 std::string((const char *)val.data(), val.size() * sizeof(float)));
269 }
270
271 //! Add index value by json format
272 void add_index_value_by_json(const std::string &json_val) override {
273 row_->mutable_index_column_values()->add_values()->set_string_value(
274 json_val);
275 }
276
277 private:
278 proto::WriteRequest::Row *row_{nullptr};
279 };
280
281 public:
282 //! Constructor
283 PbWriteRequest() = default;
284
285 //! Destructor
286 ~PbWriteRequest() override = default;
287
288 //! Set collection name, must set
289 void set_collection_name(const std::string &val) override {
290 request_.set_collection_name(val);
291 }
292
293 //! Add forward column
294 void add_forward_column(const std::string &column_name) override {
295 request_.mutable_row_meta()->add_forward_column_names(column_name);
296 }
297
298 //! Add forward columns
299 void add_forward_columns(
300 const std::vector<std::string> &column_names) override {
301 for (auto &it : column_names) {
302 request_.mutable_row_meta()->add_forward_column_names(it);
303 }
304 }
305
306 //! Add index column
307 void add_index_column(const std::string &column_name, DataType data_type,
308 uint32_t dimension) override {
309 auto *index_column = request_.mutable_row_meta()->add_index_column_metas();
310 index_column->set_column_name(column_name);
311 index_column->set_data_type((proto::DataType)data_type);
312 index_column->set_dimension(dimension);
313 }
314
315 //! Add row data, must add, can't send empty request
316 WriteRequest::RowPtr add_row() override {
317 return std::make_shared<PbWriteRequest::PbRow>(request_.add_rows());
318 }
319
320 //! Set request id for tracelog, optional set
321 void set_request_id(const std::string &request_id) override {
322 request_.set_request_id(request_id);
323 }
324
325 //! Set magic number for validation, optional set
326 void set_magic_number(uint64_t magic_number) override {
327 request_.set_magic_number(magic_number);
328 }
329
330 //! Return raw protobuf data pointer, read only
331 const proto::WriteRequest *data() const {
332 return &request_;
333 }
334
335 private:
336 proto::WriteRequest request_;
337};
338
339/*
340 * QueryRequest implementation with protobuf protocol
341 */
342class PbQueryRequest : public QueryRequest {
343 public:
344 /*
345 * KnnQueryParam implementation with protobuf protocol
346 */
347 class PbKnnQueryParam : public QueryRequest::KnnQueryParam {
348 public:
349 //! Constructor
350 PbKnnQueryParam(proto::QueryRequest::KnnQueryParam *val)
351 : knn_param_(val) {}
352
353 //! Destructor
354 ~PbKnnQueryParam() override = default;
355
356 //! Set column name, must set
357 void set_column_name(const std::string &val) override {
358 knn_param_->set_column_name(val);
359 }
360
361 //! Set topk, must set
362 void set_topk(uint32_t val) override {
363 knn_param_->set_topk(val);
364 }
365
366 //! Set features with vector array format by single
367 void set_features(const void *val, size_t val_len) override {
368 knn_param_->set_batch_count(1);
369 knn_param_->set_features((const char *)val, val_len);
370 }
371
372 //! Set features with vector array format by single
373 void set_features(const std::vector<float> &val) override {
374 knn_param_->set_features(
375 std::string((const char *)val.data(), val.size() * sizeof(float)));
376 knn_param_->set_batch_count(1);
377 knn_param_->set_data_type((proto::DataType)DataType::VECTOR_FP32);
378 knn_param_->set_dimension(val.size());
379 }
380
381 //! Set query vector with bytes format by batch
382 void set_features(const void *val, size_t val_len,
383 uint32_t batch) override {
384 knn_param_->set_batch_count(batch);
385 knn_param_->set_features((const char *)val, val_len);
386 }
387
388 //! Set features by json format
389 void set_features_by_json(const std::string &json_val) override {
390 knn_param_->set_batch_count(1);
391 knn_param_->set_matrix(json_val);
392 }
393
394
395 //! Set features by json format and by batch
396 void set_features_by_json(const std::string &json_val,
397 uint32_t batch) override {
398 knn_param_->set_batch_count(batch);
399 knn_param_->set_matrix(json_val);
400 }
401
402 //! Set search radius, default 0.0f, not open
403 void set_radius(float val) override {
404 knn_param_->set_radius(val);
405 }
406
407 //! Set if use linear search, default false
408 void set_linear(bool val) override {
409 knn_param_->set_is_linear(val);
410 }
411
412 //! Set vector data dimension, must set
413 void set_dimension(uint32_t val) override {
414 knn_param_->set_dimension(val);
415 }
416
417 //! Set vector data type, must set
418 void set_data_type(DataType val) override {
419 knn_param_->set_data_type((proto::DataType)val);
420 }
421
422 //! Add extra params, like ef_search ..etc
423 void add_extra_param(const std::string &key,
424 const std::string &val) override {
425 auto *extra_param = knn_param_->add_extra_params();
426 extra_param->set_key(key);
427 extra_param->set_value(val);
428 }
429
430 private:
431 proto::QueryRequest::KnnQueryParam *knn_param_{nullptr};
432 };
433
434 public:
435 //! Constructor
436 PbQueryRequest() = default;
437
438 //! Destructor
439 ~PbQueryRequest() override = default;
440
441 //! Set collection name, must set
442 void set_collection_name(const std::string &val) override {
443 request_.set_collection_name(val);
444 }
445
446 //! Set debug mode, optional set
447 void set_debug_mode(bool val) override {
448 request_.set_debug_mode(val);
449 }
450
451 //! Set knn query param
452 QueryRequest::KnnQueryParamPtr add_knn_query_param() override {
453 return std::make_shared<PbKnnQueryParam>(request_.mutable_knn_param());
454 }
455
456 //! Return protobuf data pointer, readonly
457 const proto::QueryRequest *data() const {
458 return &request_;
459 }
460
461 private:
462 proto::QueryRequest request_{};
463};
464
465/*
466 * Document implementation with protobuf protocol
467 */
468class PbDocument : public Document {
469 public:
470 //! Constructor
471 PbDocument(const proto::Document *doc_val) {
472 doc_ = doc_val;
473 /// Actually it shoule be in another function to load
474 /// It's safe to iterate input value, so just keep it in construtor
475 for (int i = 0; i < doc_->forward_column_values_size(); i++) {
476 auto &fwd_val = doc_->forward_column_values(i);
477 auto &key = fwd_val.key();
478 auto &val = fwd_val.value();
479 forward_map_.emplace(key, &val);
480 }
481 }
482
483 //! Destructor
484 ~PbDocument() override = default;
485
486 //! Return primary key
487 uint64_t primary_key() const override {
488 return doc_->primary_key();
489 }
490
491 //! Return knn distance score
492 float score() const override {
493 return doc_->score();
494 }
495
496 //! Return forward count
497 size_t forward_count() const override {
498 return forward_map_.size();
499 }
500
501 //! Return forward names
502 void get_forward_names(
503 std::vector<std::string> *forward_names) const override {
504 for (const auto &it : forward_map_) {
505 forward_names->emplace_back(it.first);
506 }
507 }
508
509 //! Return forward value with string type
510 void get_forward_value(const std::string &key,
511 std::string *val) const override {
512 if (forward_map_.find(key) != forward_map_.end()) {
513 *val = forward_map_.at(key)->string_value();
514 }
515 }
516
517 //! Return forward value with bool type
518 void get_forward_value(const std::string &key, bool *val) const override {
519 if (forward_map_.find(key) != forward_map_.end()) {
520 *val = forward_map_.at(key)->bool_value();
521 }
522 }
523
524 //! Return forward value with int32 type
525 void get_forward_value(const std::string &key, int32_t *val) const override {
526 if (forward_map_.find(key) != forward_map_.end()) {
527 *val = forward_map_.at(key)->int32_value();
528 }
529 }
530
531 //! Return forward value with int64 type
532 void get_forward_value(const std::string &key, int64_t *val) const override {
533 if (forward_map_.find(key) != forward_map_.end()) {
534 *val = forward_map_.at(key)->int64_value();
535 }
536 }
537
538 //! Return forward value with uint32 type
539 void get_forward_value(const std::string &key, uint32_t *val) const override {
540 if (forward_map_.find(key) != forward_map_.end()) {
541 *val = forward_map_.at(key)->uint32_value();
542 }
543 }
544
545 //! Return forward value with uint64 type
546 void get_forward_value(const std::string &key, uint64_t *val) const override {
547 if (forward_map_.find(key) != forward_map_.end()) {
548 *val = forward_map_.at(key)->uint64_value();
549 }
550 }
551
552 //! Return forward value with float type
553 void get_forward_value(const std::string &key, float *val) const override {
554 if (forward_map_.find(key) != forward_map_.end()) {
555 *val = forward_map_.at(key)->float_value();
556 }
557 }
558
559 //! Return forward value with double type
560 void get_forward_value(const std::string &key, double *val) const override {
561 if (forward_map_.find(key) != forward_map_.end()) {
562 *val = forward_map_.at(key)->double_value();
563 }
564 }
565
566 private:
567 const proto::Document *doc_{nullptr};
568 std::map<std::string, const proto::GenericValue *> forward_map_{};
569};
570
571/*
572 * QueryResponse implementation of protobuf protocol
573 */
574class PbQueryResponse : public QueryResponse {
575 public:
576 /*
577 * Result implementation of protobuf protocol
578 */
579 class PbResult : public QueryResponse::Result {
580 public:
581 //! Constructor
582 PbResult(const proto::QueryResponse::Result *val) : result_(val) {}
583
584 //! Destructor
585 ~PbResult() override = default;
586
587 //! Return document count
588 size_t document_count() const override {
589 return result_->documents_size();
590 }
591
592 //! Return document pointer of specific pos
593 DocumentPtr document(int index) const override {
594 if (index < result_->documents_size()) {
595 return std::make_shared<PbDocument>(&result_->documents(index));
596 } else {
597 return DocumentPtr();
598 }
599 }
600
601 private:
602 const proto::QueryResponse::Result *result_{nullptr};
603 };
604
605 public:
606 //! Constructor
607 PbQueryResponse() = default;
608
609 //! Destructor
610 ~PbQueryResponse() override = default;
611
612 //! Return debug info
613 const std::string &debug_info() const override {
614 return response_.debug_info();
615 }
616
617 //! Return query latency, microseconds
618 uint64_t latency_us() const override {
619 return response_.latency_us();
620 }
621
622 //! Return batch result count
623 size_t result_count() const override {
624 return response_.results_size();
625 }
626
627 //! Return result of specific batch pos
628 QueryResponse::ResultPtr result(int index) const override {
629 return std::make_shared<PbQueryResponse::PbResult>(
630 &response_.results(index));
631 }
632
633 //! Return protobuf data pointer, readonly
634 proto::QueryResponse *data() {
635 return &response_;
636 }
637
638 private:
639 proto::QueryResponse response_;
640};
641
642/*
643 * GetDocumentRequest implementation of protobuf protocol
644 */
645class PbGetDocumentRequest : public GetDocumentRequest {
646 public:
647 //! Constructor
648 PbGetDocumentRequest() = default;
649
650 //! Destructor
651 ~PbGetDocumentRequest() = default;
652
653 //! Set collection name, must set
654 void set_collection_name(const std::string &val) override {
655 request_.set_collection_name(val);
656 }
657
658 //! Set primary key, must set
659 void set_primary_key(uint64_t val) override {
660 request_.set_primary_key(val);
661 }
662
663 //! Set debug mode, default false
664 void set_debug_mode(bool val) override {
665 request_.set_debug_mode(val);
666 }
667
668 //! Return protobuf data pointer, readonly
669 const proto::GetDocumentRequest *data() const {
670 return &request_;
671 }
672
673 private:
674 proto::GetDocumentRequest request_;
675};
676
677
678/*
679 * GetDocumentResponse implementation of protobuf protocol
680 */
681class PbGetDocumentResponse : public GetDocumentResponse {
682 public:
683 //! Constructor
684 PbGetDocumentResponse() = default;
685
686 //! Destructor
687 ~PbGetDocumentResponse() override = default;
688
689 //! Return debug info
690 const std::string &debug_info() const override {
691 return response_.debug_info();
692 }
693
694 //! If not exist the key, return nullptr, or return document shared ptr
695 DocumentPtr document() const override {
696 if (response_.has_document()) {
697 return std::make_shared<PbDocument>(&response_.document());
698 } else {
699 return DocumentPtr();
700 }
701 }
702
703 //! Return protobuf data pointer
704 proto::GetDocumentResponse *data() {
705 return &response_;
706 }
707
708 private:
709 proto::GetDocumentResponse response_;
710};
711
712
713} // namespace be
714} // end namespace proxima
715