1/**
2 * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15
16 * \author Haichao.chc
17 * \date Jun 2021
18 * \brief Index builder tool can create collection index
19 * from text or vecs file
20 */
21
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <gflags/gflags.h>
#include "common/logger.h"
#include "common/protobuf_helper.h"
#include "common/types.h"
#include "common/version.h"
#include "index/collection.h"
#include "index/typedef.h"
#include "meta/meta.h"
#include "proto/proxima_be.pb.h"
#include "vecs_reader.h"
34
35using namespace proxima::be;
36
37DEFINE_string(schema, "", "Specify the schema of collection");
38DEFINE_string(file, "", "Specify input data file");
39DEFINE_string(output, "./", "Sepecify output index directory");
40DEFINE_uint32(concurrency, 10, "Threads count for building index");
41
/**
 * Flag validator: accepts a string flag only when its value is non-empty.
 *
 * The flag name is part of the validator signature but is not used here.
 *
 * \return true when \p value contains at least one character.
 */
static bool ValidateNotEmpty(const char * /*flagname*/,
                             const std::string &value) {
  const bool has_content = (value.size() != 0);
  return has_content;
}
45
46DEFINE_validator(schema, ValidateNotEmpty);
47DEFINE_validator(file, ValidateNotEmpty);
48
/**
 * One input record handed from the file loaders to DoInsertCollection.
 *
 * The numeric members are zero-initialized so that a default-constructed
 * Record never carries indeterminate values into an insert task.
 */
struct Record {
  //! Primary key of the record
  uint64_t key{0};
  //! Raw bytes of the index column vector
  std::string vector;
  //! Value stored as the forward column, may be empty
  std::string attributes;
  //! Dimension of the vector held in `vector`
  uint32_t dimension{0};
};
55
56meta::CollectionMetaPtr g_collection_meta;
57
/**
 * Print the command-line help text to standard output, one line at a time.
 */
static inline void PrintUsage() {
  // Each entry is emitted as its own line; the empty entry produces the
  // blank line that separates the synopsis from the argument list.
  static const char *const kUsageLines[] = {
      "Usage:",
      " index_builder <args>",
      "",
      "Args: ",
      " --schema Specify the schema of collection",
      " --file Specify input data file",
      " --output Sepecify output index directory(default ./)",
      " --concurrency Sepecify threads count for building index(default 10)",
      " --help, -h Dipslay help info",
      " --version, -v Dipslay version info",
  };

  for (const char *text : kUsageLines) {
    std::cout << text << std::endl;
  }
}
73
74static bool ParseSchema() {
75 // Parse protobuf object from input json
76 ProtobufHelper::JsonParseOptions options;
77 options.ignore_unknown_fields = true;
78
79 proto::CollectionConfig collection_config;
80 if (!ProtobufHelper::JsonToMessage(FLAGS_schema, &collection_config)) {
81 LOG_ERROR("JsonToMessage failed. schema[%s]", FLAGS_schema.c_str());
82 return false;
83 }
84
85 std::string converted_json;
86 ProtobufHelper::MessageToJson(collection_config, &converted_json);
87
88 // Check input schema format
89 if (collection_config.collection_name().empty()) {
90 LOG_ERROR("Collection name can't be empty. schema[%s]",
91 converted_json.c_str());
92 return false;
93 }
94
95 if (collection_config.index_column_params_size() != 1) {
96 LOG_ERROR("Schema must contain an index column. schema[%s]",
97 converted_json.c_str());
98 return false;
99 }
100
101 auto *index_column_schema = collection_config.mutable_index_column_params(0);
102 if (index_column_schema->column_name().empty()) {
103 LOG_ERROR("Schema index column name can't be empty. schema[%s]",
104 converted_json.c_str());
105 return false;
106 }
107
108 if (index_column_schema->index_type() == proto::IndexType::IT_UNDEFINED) {
109 index_column_schema->set_index_type(
110 proto::IndexType::IT_PROXIMA_GRAPH_INDEX);
111 }
112
113 if (index_column_schema->data_type() == proto::DataType::DT_UNDEFINED) {
114 index_column_schema->set_data_type(proto::DataType::DT_VECTOR_FP32);
115 }
116
117 if (index_column_schema->dimension() == 0U) {
118 LOG_ERROR("Schema index column dimension must be set. schema[%s]",
119 converted_json.c_str());
120 return false;
121 }
122
123 if (collection_config.forward_column_names_size() > 1) {
124 LOG_ERROR("Schema can contain a forward column at most. schema[%s]",
125 converted_json.c_str());
126 return false;
127 }
128
129 // Generate collection meta from schema
130 g_collection_meta = std::make_shared<meta::CollectionMeta>();
131 g_collection_meta->set_name(collection_config.collection_name());
132
133 // Set forward column
134 if (collection_config.forward_column_names_size() > 0) {
135 g_collection_meta->mutable_forward_columns()->emplace_back(
136 collection_config.forward_column_names(0));
137 }
138
139 // Set index column
140 auto &column_param = collection_config.index_column_params(0);
141 auto new_column_meta = std::make_shared<meta::ColumnMeta>();
142 new_column_meta->set_name(column_param.column_name());
143 new_column_meta->set_index_type((IndexTypes)column_param.index_type());
144 new_column_meta->set_data_type((DataTypes)column_param.data_type());
145 new_column_meta->set_dimension(column_param.dimension());
146 for (int j = 0; j < column_param.extra_params_size(); j++) {
147 auto &kvpair = column_param.extra_params(j);
148 new_column_meta->mutable_parameters()->set(kvpair.key(), kvpair.value());
149 }
150 g_collection_meta->append(new_column_meta);
151
152 std::cout << "Parse collection schema success. schema[" << FLAGS_schema << "]"
153 << std::endl;
154 return true;
155}
156
157static void DoInsertCollection(index::Collection *collection,
158 const Record &record) {
159 index::CollectionDataset records(0);
160 auto *row = records.add_row_data();
161 row->operation_type = OperationTypes::INSERT;
162 row->primary_key = record.key;
163
164 // Serialize forward data to pb format
165 if (g_collection_meta->forward_columns().size() > 0) {
166 proto::GenericValueList value_list;
167 auto *value = value_list.add_values();
168 value->set_string_value(record.attributes);
169 value_list.SerializeToString(&row->forward_data);
170 }
171
172 // Append index column data
173 auto &index_column_schema = g_collection_meta->index_columns().at(0);
174 index::ColumnData index_column;
175 index_column.column_name = index_column_schema->name();
176 index_column.data_type = (DataTypes)index_column_schema->data_type();
177 index_column.dimension = record.dimension;
178 index_column.data = record.vector;
179 row->column_datas.emplace_back(index_column);
180
181 collection->write_records(records);
182}
183
184static bool LoadFromVecsFile(aitheta2::IndexThreads::TaskGroup *group,
185 index::Collection *collection) {
186 tools::VecsReader reader;
187 if (!reader.load(FLAGS_file)) {
188 LOG_ERROR("Load vecs file failed.");
189 return false;
190 }
191
192 for (uint32_t i = 0; i < reader.num_vecs(); i++) {
193 Record record;
194 uint64_t key = reader.get_key(i);
195 const char *feature = (const char *)reader.get_vector(i);
196
197 record.key = key;
198 record.vector.append(feature, reader.index_meta().element_size());
199 record.dimension = reader.index_meta().dimension();
200 group->submit(ailego::Closure::New(DoInsertCollection, collection, record));
201 }
202 return true;
203}
204
205static bool LoadFromTextFile(aitheta2::IndexThreads::TaskGroup *group,
206 index::Collection *collection) {
207 std::ifstream file_stream(FLAGS_file);
208 if (!file_stream.is_open()) {
209 LOG_ERROR("Can't open input file[%s]", FLAGS_file.c_str());
210 return false;
211 }
212
213 std::string line;
214 while (std::getline(file_stream, line)) {
215 line.erase(line.find_last_not_of('\n') + 1);
216 if (line.empty()) {
217 continue;
218 }
219 std::vector<std::string> res;
220 ailego::StringHelper::Split(line, ';', &res);
221 if (res.size() < 2) {
222 LOG_ERROR("Bad input line, format[key;vector(1 2 3 4...);attributes]");
223 continue;
224 }
225
226 Record record;
227 // Parse key
228 uint64_t key = std::stoull(res[0]);
229 record.key = key;
230 // Parse feature
231 if (res.size() >= 2) {
232 auto data_type = g_collection_meta->index_columns().at(0)->data_type();
233 if (data_type == DataTypes::VECTOR_BINARY32) {
234 std::vector<uint8_t> vec;
235 ailego::StringHelper::Split(res[1], ' ', &vec);
236 if (vec.size() == 0 || vec.size() % 32 != 0) {
237 LOG_ERROR("Bad feature field");
238 continue;
239 }
240
241 std::vector<uint8_t> tmp;
242 for (size_t i = 0; i < vec.size(); i += 8) {
243 uint8_t v = 0;
244 v |= (vec[i] & 0x01) << 7;
245 v |= (vec[i + 1] & 0x01) << 6;
246 v |= (vec[i + 2] & 0x01) << 5;
247 v |= (vec[i + 3] & 0x01) << 4;
248 v |= (vec[i + 4] & 0x01) << 3;
249 v |= (vec[i + 5] & 0x01) << 2;
250 v |= (vec[i + 6] & 0x01) << 1;
251 v |= (vec[i + 7] & 0x01) << 0;
252 tmp.push_back(v);
253 }
254
255 record.vector =
256 std::string((const char *)tmp.data(), tmp.size() * sizeof(uint8_t));
257 record.dimension = vec.size();
258 } else {
259 if (data_type == DataTypes::VECTOR_FP32) {
260 std::vector<float> feature;
261 ailego::StringHelper::Split(res[1], ' ', &feature);
262 if (feature.size() == 0) {
263 LOG_ERROR("Bad feature field");
264 continue;
265 }
266 record.vector = std::string((const char *)feature.data(),
267 feature.size() * sizeof(float));
268 record.dimension = feature.size();
269 } else if (data_type == DataTypes::VECTOR_INT8) {
270 std::vector<int8_t> feature;
271 ailego::StringHelper::Split(res[1], ' ', &feature);
272 if (feature.size() == 0) {
273 LOG_ERROR("Bad feature field");
274 continue;
275 }
276 record.vector = std::string((const char *)feature.data(),
277 feature.size() * sizeof(int8_t));
278 record.dimension = feature.size();
279 }
280 }
281 }
282 // Parse attributes
283 if (res.size() >= 3) {
284 record.attributes = res[2];
285 }
286
287 group->submit(ailego::Closure::New(DoInsertCollection, collection, record));
288 }
289 file_stream.close();
290
291 return true;
292}
293
294static bool BuildIndex() {
295 index::ThreadPool thread_pool(FLAGS_concurrency, false);
296
297 // Create and open new collection
298 index::CollectionPtr new_collection;
299 index::ReadOptions read_options;
300 read_options.use_mmap = true;
301 read_options.create_new = true;
302 int ret = index::Collection::CreateAndOpen(
303 g_collection_meta->name(), FLAGS_output, g_collection_meta,
304 FLAGS_concurrency, &thread_pool, read_options, &new_collection);
305 if (ret != 0) {
306 return false;
307 }
308 std::cout << "Create collection complete. collection["
309 << g_collection_meta->name() << "]" << std::endl;
310
311 // Writing into collection in parallel
312 auto group = thread_pool.make_group();
313
314 if (FLAGS_file.find(".vecs") != std::string::npos) {
315 if (!LoadFromVecsFile(group.get(), new_collection.get())) {
316 return false;
317 }
318 } else {
319 if (!LoadFromTextFile(group.get(), new_collection.get())) {
320 return false;
321 }
322 }
323
324 group->wait_finish();
325 std::cout << "Build index complete. collection[" << g_collection_meta->name()
326 << "]" << std::endl;
327
328 // Dump to disk
329 ret = new_collection->dump();
330 if (ret != 0) {
331 return false;
332 }
333
334 ret = new_collection->close();
335 if (ret != 0) {
336 return false;
337 }
338 std::cout << "Dump index complete. collection[" << g_collection_meta->name()
339 << "]" << std::endl;
340
341 return true;
342}
343
344int main(int argc, char **argv) {
345 // Parse arguments
346 for (int i = 1; i < argc; ++i) {
347 const char *arg = argv[i];
348 if (!strcmp(arg, "-help") || !strcmp(arg, "--help") || !strcmp(arg, "-h")) {
349 PrintUsage();
350 exit(0);
351 } else if (!strcmp(arg, "-version") || !strcmp(arg, "--version") ||
352 !strcmp(arg, "-v")) {
353 std::cout << proxima::be::Version::Details() << std::endl;
354 exit(0);
355 }
356 }
357 gflags::ParseCommandLineNonHelpFlags(&argc, &argv, false);
358
359 // Adjust log level to prevent print too many logs
360 aitheta2::IndexLoggerBroker::SetLevel(aitheta2::IndexLogger::LEVEL_WARN);
361
362 // Parse schema
363 if (!ParseSchema()) {
364 LOG_ERROR("Parse schema failed.");
365 exit(1);
366 }
367
368 // Start to build index
369 if (!BuildIndex()) {
370 LOG_ERROR("Build index error.");
371 exit(1);
372 }
373}
374