1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | * |
16 | * \author guonix |
17 | * \date Apr 2021 |
18 | * \brief |
19 | */ |
20 | |
21 | #define private public |
22 | #include "common/transformer.h" |
23 | #undef private |
24 | #include <gtest/gtest.h> |
25 | |
26 | using namespace proxima::be; |
27 | |
28 | class TransformerTest : public testing::Test { |
29 | protected: |
30 | void SetUp() {} |
31 | |
32 | void TearDown() {} |
33 | }; |
34 | |
35 | TEST_F(TransformerTest, TestTransformJsonVector) { |
36 | size_t dimension = 512; |
37 | std::ostringstream oss; |
38 | oss << "[" ; |
39 | for (size_t i = 0; i < 512; ++i) { |
40 | if (i % 2 == 0) |
41 | oss << (i + 1) / 512.0 << "," ; |
42 | else |
43 | oss << (i + 1) / (-512.0) << "," ; |
44 | } |
45 | oss << "]" ; |
46 | std::string index_value = oss.str(); |
47 | std::string output_value; |
48 | std::vector<float> vectors; |
49 | auto ret = Transformer::Transform(index_value, nullptr, &vectors); |
50 | ASSERT_EQ(ret, dimension); |
51 | ASSERT_EQ(vectors.size(), dimension); |
52 | } |
53 | |
54 | TEST_F(TransformerTest, TestExpectedSize) { |
55 | std::ostringstream oss; |
56 | oss << "[" ; |
57 | for (size_t i = 0; i < 512; ++i) { |
58 | if (i % 2 == 0) |
59 | oss << (i + 1) / 512.0 << "," ; |
60 | else |
61 | oss << (i + 1) / (-512.0) << "," ; |
62 | } |
63 | oss << "]" ; |
64 | std::string index_value = oss.str(); |
65 | std::string output_value; |
66 | std::vector<float> vectors; |
67 | auto ret = Transformer::Transform(index_value, nullptr, &vectors); |
68 | ASSERT_EQ(ret, 512); |
69 | ASSERT_EQ(vectors.size(), 512); |
70 | } |
71 | |
72 | TEST_F(TransformerTest, TestInvalidVectorFormat) { |
73 | std::ostringstream oss; |
74 | oss << "{\"a\":1}" ; |
75 | std::string index_value = oss.str(); |
76 | std::string output_value; |
77 | std::vector<float> vectors; |
78 | auto ret = Transformer::Transform(index_value, nullptr, &vectors); |
79 | ASSERT_EQ(ret, PROXIMA_BE_ERROR_CODE(InvalidVectorFormat)); |
80 | } |
81 | |
82 | TEST_F(TransformerTest, TestParseJsonVectorFailedWithInvalidType) { |
83 | std::ostringstream oss; |
84 | oss << "[" ; |
85 | for (size_t i = 0; i < 512; ++i) { |
86 | if (i % 2 == 0) |
87 | oss << (i + 1) / 512.0 << "," ; |
88 | else |
89 | oss << (i + 1) / (-512.0) << "," ; |
90 | } |
91 | std::string index_value = oss.str(); |
92 | std::string output_value; |
93 | std::vector<float> vectors; |
94 | auto ret = Transformer::Transform(index_value, nullptr, &vectors); |
95 | |
96 | ASSERT_EQ(ret, PROXIMA_BE_ERROR_CODE(InvalidVectorFormat)); |
97 | } |
98 | |
99 | |
100 | TEST_F(TransformerTest, TestInt82Int4) { |
101 | std::string index_value("[1,2,3,4,5,6]" ); |
102 | std::string output_value; |
103 | std::vector<int8_t> values; |
104 | Transformer::Transform(index_value, nullptr, &values); |
105 | Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(values, &output_value); |
106 | const uint8_t *data = (const uint8_t *)(&(output_value[0])); |
107 | for (uint32_t i = 1; i <= 3; ++i) { |
108 | ASSERT_FLOAT_EQ((int8_t)(2 * i - 1), (int8_t)(data[i - 1] & 0xf)); |
109 | ASSERT_FLOAT_EQ((uint8_t)(2 * i), (int8_t)(data[i - 1] >> 4)); |
110 | } |
111 | } |
112 | |
113 | TEST_F(TransformerTest, TestFP32ToFP16) { |
114 | std::string index_value("[1,2,3,4,5,6]" ); |
115 | uint32_t dimension = 6; |
116 | std::string output_value; |
117 | std::vector<float> values; |
118 | Transformer::Transform(index_value, nullptr, &values); |
119 | Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(values, &output_value); |
120 | const uint16_t *data = (const uint16_t *)(&(output_value[0])); |
121 | for (uint32_t i = 1; i <= dimension; ++i) { |
122 | ASSERT_FLOAT_EQ(1.0f * i, ailego::FloatHelper::ToFP32(data[i - 1])); |
123 | } |
124 | } |