1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Apr 2021
 * \brief Unit tests for Transformer::Transform and Primary2Bytes::Bytes
19 */
20
#define private public
#include "common/transformer.h"
#undef private
#include <gtest/gtest.h>
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>
25
26using namespace proxima::be;
27
28class TransformerTest : public testing::Test {
29 protected:
30 void SetUp() {}
31
32 void TearDown() {}
33};
34
35TEST_F(TransformerTest, TestTransformJsonVector) {
36 size_t dimension = 512;
37 std::ostringstream oss;
38 oss << "[";
39 for (size_t i = 0; i < 512; ++i) {
40 if (i % 2 == 0)
41 oss << (i + 1) / 512.0 << ",";
42 else
43 oss << (i + 1) / (-512.0) << ",";
44 }
45 oss << "]";
46 std::string index_value = oss.str();
47 std::string output_value;
48 std::vector<float> vectors;
49 auto ret = Transformer::Transform(index_value, nullptr, &vectors);
50 ASSERT_EQ(ret, dimension);
51 ASSERT_EQ(vectors.size(), dimension);
52}
53
54TEST_F(TransformerTest, TestExpectedSize) {
55 std::ostringstream oss;
56 oss << "[";
57 for (size_t i = 0; i < 512; ++i) {
58 if (i % 2 == 0)
59 oss << (i + 1) / 512.0 << ",";
60 else
61 oss << (i + 1) / (-512.0) << ",";
62 }
63 oss << "]";
64 std::string index_value = oss.str();
65 std::string output_value;
66 std::vector<float> vectors;
67 auto ret = Transformer::Transform(index_value, nullptr, &vectors);
68 ASSERT_EQ(ret, 512);
69 ASSERT_EQ(vectors.size(), 512);
70}
71
72TEST_F(TransformerTest, TestInvalidVectorFormat) {
73 std::ostringstream oss;
74 oss << "{\"a\":1}";
75 std::string index_value = oss.str();
76 std::string output_value;
77 std::vector<float> vectors;
78 auto ret = Transformer::Transform(index_value, nullptr, &vectors);
79 ASSERT_EQ(ret, PROXIMA_BE_ERROR_CODE(InvalidVectorFormat));
80}
81
82TEST_F(TransformerTest, TestParseJsonVectorFailedWithInvalidType) {
83 std::ostringstream oss;
84 oss << "[";
85 for (size_t i = 0; i < 512; ++i) {
86 if (i % 2 == 0)
87 oss << (i + 1) / 512.0 << ",";
88 else
89 oss << (i + 1) / (-512.0) << ",";
90 }
91 std::string index_value = oss.str();
92 std::string output_value;
93 std::vector<float> vectors;
94 auto ret = Transformer::Transform(index_value, nullptr, &vectors);
95
96 ASSERT_EQ(ret, PROXIMA_BE_ERROR_CODE(InvalidVectorFormat));
97}
98
99
100TEST_F(TransformerTest, TestInt82Int4) {
101 std::string index_value("[1,2,3,4,5,6]");
102 std::string output_value;
103 std::vector<int8_t> values;
104 Transformer::Transform(index_value, nullptr, &values);
105 Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(values, &output_value);
106 const uint8_t *data = (const uint8_t *)(&(output_value[0]));
107 for (uint32_t i = 1; i <= 3; ++i) {
108 ASSERT_FLOAT_EQ((int8_t)(2 * i - 1), (int8_t)(data[i - 1] & 0xf));
109 ASSERT_FLOAT_EQ((uint8_t)(2 * i), (int8_t)(data[i - 1] >> 4));
110 }
111}
112
113TEST_F(TransformerTest, TestFP32ToFP16) {
114 std::string index_value("[1,2,3,4,5,6]");
115 uint32_t dimension = 6;
116 std::string output_value;
117 std::vector<float> values;
118 Transformer::Transform(index_value, nullptr, &values);
119 Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(values, &output_value);
120 const uint16_t *data = (const uint16_t *)(&(output_value[0]));
121 for (uint32_t i = 1; i <= dimension; ++i) {
122 ASSERT_FLOAT_EQ(1.0f * i, ailego::FloatHelper::ToFP32(data[i - 1]));
123 }
124}