1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15
16 * \author Hechong.xyf
17 * \date May 2019
18 * \brief Interface of AiTheta Index Measure
19 */
20
21#ifndef __AITHETA2_INDEX_MEASURE_H__
22#define __AITHETA2_INDEX_MEASURE_H__
23
24#include "index_meta.h"
25#include "index_module.h"
26
27namespace aitheta2 {
28
29/*! Index Measure
30 */
31struct IndexMeasure : public IndexModule {
32 //! Index Measure Pointer
33 typedef std::shared_ptr<IndexMeasure> Pointer;
34
35 //! Matrix Distance Function
36 typedef void (*MatrixDistanceHandle)(const void *m, const void *q, size_t dim,
37 float *out);
38
39 //! Matrix Distance Function Object
40 using MatrixDistance =
41 std::function<void(const void *m, const void *q, size_t dim, float *out)>;
42
43 //! Destructor
44 virtual ~IndexMeasure(void) {}
45
46 //! Initialize Measure
47 virtual int init(const IndexMeta &meta, const IndexParams &params) = 0;
48
49 //! Cleanup Measure
50 virtual int cleanup(void) = 0;
51
52 //! Retrieve if it matched
53 virtual bool is_matched(const IndexMeta &meta) const = 0;
54
55 //! Retrieve if it matched
56 virtual bool is_matched(const IndexMeta &meta,
57 const IndexQueryMeta &qmeta) const = 0;
58
59 //! Retrieve distance function for query
60 virtual MatrixDistance distance(void) const = 0;
61
62 //! Retrieve distance function for index features
63 virtual MatrixDistance distance_matrix(size_t m, size_t n) const = 0;
64
65 //! Retrieve params of Measure
66 virtual const IndexParams &params(void) const = 0;
67
68 //! Retrieve query measure object of this index measure
69 virtual Pointer query_measure(void) const = 0;
70
71 //! Normalize result
72 virtual void normalize(float *score) const {
73 (void)score;
74 }
75
76 //! Retrieve if it supports normalization
77 virtual bool support_normalize(void) const {
78 return false;
79 }
80
81 //! Train the measure
82 virtual int train(const void *vec, size_t dim) {
83 (void)vec;
84 (void)dim;
85 return 0;
86 }
87
88 //! Retrieve if it supports training
89 virtual bool support_train(void) const {
90 return false;
91 }
92
93 //! Compute the distance between feature and query
94 float distance(const void *m, const void *q, size_t dim) const {
95 float dist;
96 (this->distance())(m, q, dim, &dist);
97 return dist;
98 }
99};
100
101} // namespace aitheta2
102
103#endif // __AITHETA2_INDEX_MEASURE_H__
104