1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 * \author guonix
17 * \date Mar 2021
18 * \brief
19 */
20
21#pragma once
22
23#include <ailego/encoding/json.h>
24#include <ailego/utility/float_helper.h>
25#include "error_code.h"
26#include "logger.h"
27#include "types.h"
28
29namespace proxima {
30namespace be {
31
32/*! Transformer
33 */
34class Transformer {
35 public:
36 //! Check transform action
37 static bool NeedTransform(DataTypes in, DataTypes out);
38
39 //! Check support current transform
40 static int SupportTransform(DataTypes in_type, DataTypes out_type);
41
42 //! Transform input features to output features
43 // return 0 for success otherwise failed
44 static int Transform(DataTypes, const std::string &, DataTypes,
45 std::string *);
46
47 //! Transform json array to std::vector, return count of element has been
48 // transformed
49 template <typename T>
50 static size_t Transform(const ailego::JsonArray &array,
51 std::vector<T> *values);
52
53 //! Interpret json value to vector, param matrix should contains valid
54 // matrix or array object.
55 template <typename T>
56 static size_t Transform(const ailego::JsonValue &matrix,
57 std::vector<T> *values);
58
59 //! Interpret json string to vector, param json should contains valid
60 // matrix or array object.
61 template <typename T>
62 static size_t Transform(
63 const std::string &json,
64 std::function<int(const ailego::JsonValue &)> *validator,
65 std::vector<T> *values);
66
67 //! transform vector to bytes
68 template <typename T, DataTypes>
69 static size_t Transform(const std::vector<T> &values, std::string *bytes);
70};
71
72/*!
73 * Transform json object to primary data
74 */
75struct Json2Primary {
76 template <class T>
77 static T Primary(const ailego::JsonValue &object) {
78 return static_cast<T>(object.as_integer());
79 }
80};
81
82//! Transform double of JsonValue to float
83template <>
84float Json2Primary::Primary<float>(const ailego::JsonValue &object);
85
86//! Transform double of JsonValue
87template <>
88double Json2Primary::Primary<double>(const ailego::JsonValue &object);
89
90template <typename T>
91size_t Transformer::Transform(const ailego::JsonArray &array,
92 std::vector<T> *values) {
93 for (auto it = array.begin(); it != array.end(); ++it) {
94 values->emplace_back(Json2Primary::Primary<T>(*it));
95 }
96 return size_t(array.size());
97}
98
99template <typename T>
100size_t Transformer::Transform(const ailego::JsonValue &matrix,
101 std::vector<T> *values) {
102 try {
103 auto &array = matrix.as_array();
104 size_t size = 0;
105 if (!array.empty() && array.begin()->is_array()) {
106 for (auto it = array.begin(); it != array.end(); ++it) {
107 size_t temp = Transformer::Transform(it->as_array(), values);
108 if (temp < 0) {
109 return temp;
110 }
111 size += temp;
112 }
113 } else {
114 size = Transformer::Transform(matrix.as_array(), values);
115 }
116 return size;
117 } catch (const std::exception &e) {
118 return PROXIMA_BE_ERROR_CODE(InvalidVectorFormat);
119 }
120}
121
122template <typename T>
123size_t Transformer::Transform(
124 const std::string &json,
125 std::function<int(const ailego::JsonValue &)> *validator,
126 std::vector<T> *values) {
127 ailego::JsonValue node;
128 if (!node.parse(json.c_str())) {
129 LOG_ERROR("Parse index json value failed.");
130 return PROXIMA_BE_ERROR_CODE(InvalidVectorFormat);
131 }
132
133 int error_code = !validator ? 0 : (*validator)(node);
134 if (error_code != 0) {
135 return size_t(error_code);
136 }
137 values->clear();
138 return Transformer::Transform(node, values);
139}
140
141/*!
142 * Serialize primary array into bytes
143 */
144struct Primary2Bytes {
145 template <class T, DataTypes>
146 static void Bytes(const std::vector<T> &values, std::string *bytes) {
147 auto capacity = values.size() * sizeof(T);
148 bytes->resize(capacity);
149 std::memcpy(&((*bytes)[0]), values.data(), capacity);
150 }
151};
152
153template <>
154void Primary2Bytes::Bytes<int8_t, DataTypes::VECTOR_INT4>(
155 const std::vector<int8_t> &values, std::string *bytes);
156
157template <>
158void Primary2Bytes::Bytes<float, DataTypes::VECTOR_FP16>(
159 const std::vector<float> &values, std::string *bytes);
160
161template <typename T, DataTypes DT>
162size_t Transformer::Transform(const std::vector<T> &values,
163 std::string *bytes) {
164 Primary2Bytes::Bytes<T, DT>(values, bytes);
165 return values.size();
166}
167
168} // namespace be
169} // namespace proxima
170