1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | * |
16 | * \author guonix |
17 | * \date Mar 2021 |
18 | * \brief |
19 | */ |
20 | |
21 | #pragma once |
22 | |
#include <cstring>
#include <exception>
#include <functional>
#include <string>
#include <type_traits>
#include <vector>
#include <ailego/encoding/json.h>
#include <ailego/utility/float_helper.h>
#include "error_code.h"
#include "logger.h"
#include "types.h"
28 | |
29 | namespace proxima { |
30 | namespace be { |
31 | |
32 | /*! Transformer |
33 | */ |
34 | class Transformer { |
35 | public: |
36 | //! Check transform action |
37 | static bool NeedTransform(DataTypes in, DataTypes out); |
38 | |
39 | //! Check support current transform |
40 | static int SupportTransform(DataTypes in_type, DataTypes out_type); |
41 | |
42 | //! Transform input features to output features |
43 | // return 0 for success otherwise failed |
44 | static int Transform(DataTypes, const std::string &, DataTypes, |
45 | std::string *); |
46 | |
47 | //! Transform json array to std::vector, return count of element has been |
48 | // transformed |
49 | template <typename T> |
50 | static size_t Transform(const ailego::JsonArray &array, |
51 | std::vector<T> *values); |
52 | |
53 | //! Interpret json value to vector, param matrix should contains valid |
54 | // matrix or array object. |
55 | template <typename T> |
56 | static size_t Transform(const ailego::JsonValue &matrix, |
57 | std::vector<T> *values); |
58 | |
59 | //! Interpret json string to vector, param json should contains valid |
60 | // matrix or array object. |
61 | template <typename T> |
62 | static size_t Transform( |
63 | const std::string &json, |
64 | std::function<int(const ailego::JsonValue &)> *validator, |
65 | std::vector<T> *values); |
66 | |
67 | //! transform vector to bytes |
68 | template <typename T, DataTypes> |
69 | static size_t Transform(const std::vector<T> &values, std::string *bytes); |
70 | }; |
71 | |
72 | /*! |
73 | * Transform json object to primary data |
74 | */ |
75 | struct Json2Primary { |
76 | template <class T> |
77 | static T Primary(const ailego::JsonValue &object) { |
78 | return static_cast<T>(object.as_integer()); |
79 | } |
80 | }; |
81 | |
//! Specialization for float: transform double of JsonValue to float.
//! Replaces the generic Primary<T>, which reads the value through
//! as_integer(). Declaration only; the definition is not in this section.
template <>
float Json2Primary::Primary<float>(const ailego::JsonValue &object);

//! Specialization for double: transform double of JsonValue.
//! Replaces the generic Primary<T>, which reads the value through
//! as_integer(). Declaration only; the definition is not in this section.
template <>
double Json2Primary::Primary<double>(const ailego::JsonValue &object);
89 | |
90 | template <typename T> |
91 | size_t Transformer::Transform(const ailego::JsonArray &array, |
92 | std::vector<T> *values) { |
93 | for (auto it = array.begin(); it != array.end(); ++it) { |
94 | values->emplace_back(Json2Primary::Primary<T>(*it)); |
95 | } |
96 | return size_t(array.size()); |
97 | } |
98 | |
99 | template <typename T> |
100 | size_t Transformer::Transform(const ailego::JsonValue &matrix, |
101 | std::vector<T> *values) { |
102 | try { |
103 | auto &array = matrix.as_array(); |
104 | size_t size = 0; |
105 | if (!array.empty() && array.begin()->is_array()) { |
106 | for (auto it = array.begin(); it != array.end(); ++it) { |
107 | size_t temp = Transformer::Transform(it->as_array(), values); |
108 | if (temp < 0) { |
109 | return temp; |
110 | } |
111 | size += temp; |
112 | } |
113 | } else { |
114 | size = Transformer::Transform(matrix.as_array(), values); |
115 | } |
116 | return size; |
117 | } catch (const std::exception &e) { |
118 | return PROXIMA_BE_ERROR_CODE(InvalidVectorFormat); |
119 | } |
120 | } |
121 | |
122 | template <typename T> |
123 | size_t Transformer::Transform( |
124 | const std::string &json, |
125 | std::function<int(const ailego::JsonValue &)> *validator, |
126 | std::vector<T> *values) { |
127 | ailego::JsonValue node; |
128 | if (!node.parse(json.c_str())) { |
129 | LOG_ERROR("Parse index json value failed." ); |
130 | return PROXIMA_BE_ERROR_CODE(InvalidVectorFormat); |
131 | } |
132 | |
133 | int error_code = !validator ? 0 : (*validator)(node); |
134 | if (error_code != 0) { |
135 | return size_t(error_code); |
136 | } |
137 | values->clear(); |
138 | return Transformer::Transform(node, values); |
139 | } |
140 | |
141 | /*! |
142 | * Serialize primary array into bytes |
143 | */ |
144 | struct Primary2Bytes { |
145 | template <class T, DataTypes> |
146 | static void Bytes(const std::vector<T> &values, std::string *bytes) { |
147 | auto capacity = values.size() * sizeof(T); |
148 | bytes->resize(capacity); |
149 | std::memcpy(&((*bytes)[0]), values.data(), capacity); |
150 | } |
151 | }; |
152 | |
//! Specialization serializing int8_t values as DataTypes::VECTOR_INT4.
//! Declaration only; the byte layout is determined by the definition, which
//! is not in this section.
template <>
void Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(
    const std::vector<int8_t> &values, std::string *bytes);

//! Specialization serializing float values as DataTypes::VECTOR_FP16.
//! Declaration only; the byte layout is determined by the definition, which
//! is not in this section.
template <>
void Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(
    const std::vector<float> &values, std::string *bytes);
160 | |
161 | template <typename T, DataTypes DT> |
162 | size_t Transformer::Transform(const std::vector<T> &values, |
163 | std::string *bytes) { |
164 | Primary2Bytes::Bytes<T, DT>(values, bytes); |
165 | return values.size(); |
166 | } |
167 | |
168 | } // namespace be |
169 | } // namespace proxima |
170 | |