1 | /** |
2 | * Copyright 2021 Alibaba, Inc. and its affiliates. All Rights Reserved. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | |
16 | * \author Haichao.chc |
17 | * \date Jun 2021 |
18 | * \brief Index builder tool can create collection index |
19 | * from text or vecs file |
20 | */ |
21 | |
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <gflags/gflags.h>
#include "common/logger.h"
#include "common/protobuf_helper.h"
#include "common/types.h"
#include "common/version.h"
#include "index/collection.h"
#include "index/typedef.h"
#include "meta/meta.h"
#include "proto/proxima_be.pb.h"
#include "vecs_reader.h"
34 | |
35 | using namespace proxima::be; |
36 | |
37 | DEFINE_string(schema, "" , "Specify the schema of collection" ); |
38 | DEFINE_string(file, "" , "Specify input data file" ); |
39 | DEFINE_string(output, "./" , "Sepecify output index directory" ); |
40 | DEFINE_uint32(concurrency, 10, "Threads count for building index" ); |
41 | |
42 | static bool ValidateNotEmpty(const char *flagname, const std::string &value) { |
43 | return !value.empty(); |
44 | } |
45 | |
// --schema and --file have empty defaults and are mandatory: register
// ValidateNotEmpty for both with gflags.
DEFINE_validator(schema, ValidateNotEmpty);
DEFINE_validator(file, ValidateNotEmpty);
48 | |
//! One input row to be inserted into the collection.
//!
//! Numeric members are zero-initialized so that a Record which is only
//! partially filled by a loader never carries an indeterminate key or
//! dimension into DoInsertCollection().
struct Record {
  //! Primary key of the row
  uint64_t key{0};
  //! Raw bytes of the index column vector
  std::string vector;
  //! Value stored in the (single) forward column, if the schema has one
  std::string attributes;
  //! Number of dimensions of `vector`
  uint32_t dimension{0};
};
55 | |
56 | meta::CollectionMetaPtr g_collection_meta; |
57 | |
//! Print the command line usage of index_builder to stdout.
static inline void PrintUsage() {
  std::cout << "Usage:" << std::endl;
  std::cout << " index_builder <args>" << std::endl << std::endl;
  std::cout << "Args: " << std::endl;
  std::cout << " --schema Specify the schema of collection" << std::endl;
  std::cout << " --file Specify input data file" << std::endl;
  std::cout << " --output Specify output index directory(default ./)"
            << std::endl;
  std::cout << " --concurrency Specify threads count for building "
               "index(default 10)"
            << std::endl;
  std::cout << " --help, -h Display help info" << std::endl;
  std::cout << " --version, -v Display version info" << std::endl;
}
73 | |
74 | static bool ParseSchema() { |
75 | // Parse protobuf object from input json |
76 | ProtobufHelper::JsonParseOptions options; |
77 | options.ignore_unknown_fields = true; |
78 | |
79 | proto::CollectionConfig collection_config; |
80 | if (!ProtobufHelper::JsonToMessage(FLAGS_schema, &collection_config)) { |
81 | LOG_ERROR("JsonToMessage failed. schema[%s]" , FLAGS_schema.c_str()); |
82 | return false; |
83 | } |
84 | |
85 | std::string converted_json; |
86 | ProtobufHelper::MessageToJson(collection_config, &converted_json); |
87 | |
88 | // Check input schema format |
89 | if (collection_config.collection_name().empty()) { |
90 | LOG_ERROR("Collection name can't be empty. schema[%s]" , |
91 | converted_json.c_str()); |
92 | return false; |
93 | } |
94 | |
95 | if (collection_config.index_column_params_size() != 1) { |
96 | LOG_ERROR("Schema must contain an index column. schema[%s]" , |
97 | converted_json.c_str()); |
98 | return false; |
99 | } |
100 | |
101 | auto *index_column_schema = collection_config.mutable_index_column_params(0); |
102 | if (index_column_schema->column_name().empty()) { |
103 | LOG_ERROR("Schema index column name can't be empty. schema[%s]" , |
104 | converted_json.c_str()); |
105 | return false; |
106 | } |
107 | |
108 | if (index_column_schema->index_type() == proto::IndexType::IT_UNDEFINED) { |
109 | index_column_schema->set_index_type( |
110 | proto::IndexType::IT_PROXIMA_GRAPH_INDEX); |
111 | } |
112 | |
113 | if (index_column_schema->data_type() == proto::DataType::DT_UNDEFINED) { |
114 | index_column_schema->set_data_type(proto::DataType::DT_VECTOR_FP32); |
115 | } |
116 | |
117 | if (index_column_schema->dimension() == 0U) { |
118 | LOG_ERROR("Schema index column dimension must be set. schema[%s]" , |
119 | converted_json.c_str()); |
120 | return false; |
121 | } |
122 | |
123 | if (collection_config.forward_column_names_size() > 1) { |
124 | LOG_ERROR("Schema can contain a forward column at most. schema[%s]" , |
125 | converted_json.c_str()); |
126 | return false; |
127 | } |
128 | |
129 | // Generate collection meta from schema |
130 | g_collection_meta = std::make_shared<meta::CollectionMeta>(); |
131 | g_collection_meta->set_name(collection_config.collection_name()); |
132 | |
133 | // Set forward column |
134 | if (collection_config.forward_column_names_size() > 0) { |
135 | g_collection_meta->mutable_forward_columns()->emplace_back( |
136 | collection_config.forward_column_names(0)); |
137 | } |
138 | |
139 | // Set index column |
140 | auto &column_param = collection_config.index_column_params(0); |
141 | auto new_column_meta = std::make_shared<meta::ColumnMeta>(); |
142 | new_column_meta->set_name(column_param.column_name()); |
143 | new_column_meta->set_index_type((IndexTypes)column_param.index_type()); |
144 | new_column_meta->set_data_type((DataTypes)column_param.data_type()); |
145 | new_column_meta->set_dimension(column_param.dimension()); |
146 | for (int j = 0; j < column_param.extra_params_size(); j++) { |
147 | auto &kvpair = column_param.extra_params(j); |
148 | new_column_meta->mutable_parameters()->set(kvpair.key(), kvpair.value()); |
149 | } |
150 | g_collection_meta->append(new_column_meta); |
151 | |
152 | std::cout << "Parse collection schema success. schema[" << FLAGS_schema << "]" |
153 | << std::endl; |
154 | return true; |
155 | } |
156 | |
157 | static void DoInsertCollection(index::Collection *collection, |
158 | const Record &record) { |
159 | index::CollectionDataset records(0); |
160 | auto *row = records.add_row_data(); |
161 | row->operation_type = OperationTypes::INSERT; |
162 | row->primary_key = record.key; |
163 | |
164 | // Serialize forward data to pb format |
165 | if (g_collection_meta->forward_columns().size() > 0) { |
166 | proto::GenericValueList value_list; |
167 | auto *value = value_list.add_values(); |
168 | value->set_string_value(record.attributes); |
169 | value_list.SerializeToString(&row->forward_data); |
170 | } |
171 | |
172 | // Append index column data |
173 | auto &index_column_schema = g_collection_meta->index_columns().at(0); |
174 | index::ColumnData index_column; |
175 | index_column.column_name = index_column_schema->name(); |
176 | index_column.data_type = (DataTypes)index_column_schema->data_type(); |
177 | index_column.dimension = record.dimension; |
178 | index_column.data = record.vector; |
179 | row->column_datas.emplace_back(index_column); |
180 | |
181 | collection->write_records(records); |
182 | } |
183 | |
184 | static bool LoadFromVecsFile(aitheta2::IndexThreads::TaskGroup *group, |
185 | index::Collection *collection) { |
186 | tools::VecsReader reader; |
187 | if (!reader.load(FLAGS_file)) { |
188 | LOG_ERROR("Load vecs file failed." ); |
189 | return false; |
190 | } |
191 | |
192 | for (uint32_t i = 0; i < reader.num_vecs(); i++) { |
193 | Record record; |
194 | uint64_t key = reader.get_key(i); |
195 | const char *feature = (const char *)reader.get_vector(i); |
196 | |
197 | record.key = key; |
198 | record.vector.append(feature, reader.index_meta().element_size()); |
199 | record.dimension = reader.index_meta().dimension(); |
200 | group->submit(ailego::Closure::New(DoInsertCollection, collection, record)); |
201 | } |
202 | return true; |
203 | } |
204 | |
205 | static bool LoadFromTextFile(aitheta2::IndexThreads::TaskGroup *group, |
206 | index::Collection *collection) { |
207 | std::ifstream file_stream(FLAGS_file); |
208 | if (!file_stream.is_open()) { |
209 | LOG_ERROR("Can't open input file[%s]" , FLAGS_file.c_str()); |
210 | return false; |
211 | } |
212 | |
213 | std::string line; |
214 | while (std::getline(file_stream, line)) { |
215 | line.erase(line.find_last_not_of('\n') + 1); |
216 | if (line.empty()) { |
217 | continue; |
218 | } |
219 | std::vector<std::string> res; |
220 | ailego::StringHelper::Split(line, ';', &res); |
221 | if (res.size() < 2) { |
222 | LOG_ERROR("Bad input line, format[key;vector(1 2 3 4...);attributes]" ); |
223 | continue; |
224 | } |
225 | |
226 | Record record; |
227 | // Parse key |
228 | uint64_t key = std::stoull(res[0]); |
229 | record.key = key; |
230 | // Parse feature |
231 | if (res.size() >= 2) { |
232 | auto data_type = g_collection_meta->index_columns().at(0)->data_type(); |
233 | if (data_type == DataTypes::VECTOR_BINARY32) { |
234 | std::vector<uint8_t> vec; |
235 | ailego::StringHelper::Split(res[1], ' ', &vec); |
236 | if (vec.size() == 0 || vec.size() % 32 != 0) { |
237 | LOG_ERROR("Bad feature field" ); |
238 | continue; |
239 | } |
240 | |
241 | std::vector<uint8_t> tmp; |
242 | for (size_t i = 0; i < vec.size(); i += 8) { |
243 | uint8_t v = 0; |
244 | v |= (vec[i] & 0x01) << 7; |
245 | v |= (vec[i + 1] & 0x01) << 6; |
246 | v |= (vec[i + 2] & 0x01) << 5; |
247 | v |= (vec[i + 3] & 0x01) << 4; |
248 | v |= (vec[i + 4] & 0x01) << 3; |
249 | v |= (vec[i + 5] & 0x01) << 2; |
250 | v |= (vec[i + 6] & 0x01) << 1; |
251 | v |= (vec[i + 7] & 0x01) << 0; |
252 | tmp.push_back(v); |
253 | } |
254 | |
255 | record.vector = |
256 | std::string((const char *)tmp.data(), tmp.size() * sizeof(uint8_t)); |
257 | record.dimension = vec.size(); |
258 | } else { |
259 | if (data_type == DataTypes::VECTOR_FP32) { |
260 | std::vector<float> feature; |
261 | ailego::StringHelper::Split(res[1], ' ', &feature); |
262 | if (feature.size() == 0) { |
263 | LOG_ERROR("Bad feature field" ); |
264 | continue; |
265 | } |
266 | record.vector = std::string((const char *)feature.data(), |
267 | feature.size() * sizeof(float)); |
268 | record.dimension = feature.size(); |
269 | } else if (data_type == DataTypes::VECTOR_INT8) { |
270 | std::vector<int8_t> feature; |
271 | ailego::StringHelper::Split(res[1], ' ', &feature); |
272 | if (feature.size() == 0) { |
273 | LOG_ERROR("Bad feature field" ); |
274 | continue; |
275 | } |
276 | record.vector = std::string((const char *)feature.data(), |
277 | feature.size() * sizeof(int8_t)); |
278 | record.dimension = feature.size(); |
279 | } |
280 | } |
281 | } |
282 | // Parse attributes |
283 | if (res.size() >= 3) { |
284 | record.attributes = res[2]; |
285 | } |
286 | |
287 | group->submit(ailego::Closure::New(DoInsertCollection, collection, record)); |
288 | } |
289 | file_stream.close(); |
290 | |
291 | return true; |
292 | } |
293 | |
294 | static bool BuildIndex() { |
295 | index::ThreadPool thread_pool(FLAGS_concurrency, false); |
296 | |
297 | // Create and open new collection |
298 | index::CollectionPtr new_collection; |
299 | index::ReadOptions read_options; |
300 | read_options.use_mmap = true; |
301 | read_options.create_new = true; |
302 | int ret = index::Collection::CreateAndOpen( |
303 | g_collection_meta->name(), FLAGS_output, g_collection_meta, |
304 | FLAGS_concurrency, &thread_pool, read_options, &new_collection); |
305 | if (ret != 0) { |
306 | return false; |
307 | } |
308 | std::cout << "Create collection complete. collection[" |
309 | << g_collection_meta->name() << "]" << std::endl; |
310 | |
311 | // Writing into collection in parallel |
312 | auto group = thread_pool.make_group(); |
313 | |
314 | if (FLAGS_file.find(".vecs" ) != std::string::npos) { |
315 | if (!LoadFromVecsFile(group.get(), new_collection.get())) { |
316 | return false; |
317 | } |
318 | } else { |
319 | if (!LoadFromTextFile(group.get(), new_collection.get())) { |
320 | return false; |
321 | } |
322 | } |
323 | |
324 | group->wait_finish(); |
325 | std::cout << "Build index complete. collection[" << g_collection_meta->name() |
326 | << "]" << std::endl; |
327 | |
328 | // Dump to disk |
329 | ret = new_collection->dump(); |
330 | if (ret != 0) { |
331 | return false; |
332 | } |
333 | |
334 | ret = new_collection->close(); |
335 | if (ret != 0) { |
336 | return false; |
337 | } |
338 | std::cout << "Dump index complete. collection[" << g_collection_meta->name() |
339 | << "]" << std::endl; |
340 | |
341 | return true; |
342 | } |
343 | |
344 | int main(int argc, char **argv) { |
345 | // Parse arguments |
346 | for (int i = 1; i < argc; ++i) { |
347 | const char *arg = argv[i]; |
348 | if (!strcmp(arg, "-help" ) || !strcmp(arg, "--help" ) || !strcmp(arg, "-h" )) { |
349 | PrintUsage(); |
350 | exit(0); |
351 | } else if (!strcmp(arg, "-version" ) || !strcmp(arg, "--version" ) || |
352 | !strcmp(arg, "-v" )) { |
353 | std::cout << proxima::be::Version::Details() << std::endl; |
354 | exit(0); |
355 | } |
356 | } |
357 | gflags::ParseCommandLineNonHelpFlags(&argc, &argv, false); |
358 | |
359 | // Adjust log level to prevent print too many logs |
360 | aitheta2::IndexLoggerBroker::SetLevel(aitheta2::IndexLogger::LEVEL_WARN); |
361 | |
362 | // Parse schema |
363 | if (!ParseSchema()) { |
364 | LOG_ERROR("Parse schema failed." ); |
365 | exit(1); |
366 | } |
367 | |
368 | // Start to build index |
369 | if (!BuildIndex()) { |
370 | LOG_ERROR("Build index error." ); |
371 | exit(1); |
372 | } |
373 | } |
374 | |