1 | /******************************************************************************* |
2 | Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. |
3 | The use of this software is governed by the LICENSE file. |
4 | *******************************************************************************/ |
5 | |
6 | #pragma once |
7 | |
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "taichi/common/json.h"
#include "taichi/common/json_serde.h"
23 | |
// When built as part of Taichi (TI_INCLUDED), everything below lives in
// namespace taichi; the namespace is closed at the very end of this header.
// Otherwise minimal stand-ins for the logging/assert macros are provided.
// NOTE(review): in the stand-alone branch TI_TRACE/TI_CRITICAL expand to
// nothing, so their arguments become a discarded parenthesized expression,
// and TI_DEBUG / TI_ERROR / TI_ASSERT_INFO (used further down) are not
// defined at all — confirm the stand-alone mode is still supported.
#ifdef TI_INCLUDED
namespace taichi {
#else
#define TI_TRACE
#define TI_CRITICAL
#define TI_ASSERT assert
#endif

// Factory that creates an instance of a Unit-derived type from its name.
// Only declared here; BinarySerializer uses it to re-create polymorphic
// objects from the name stored in the stream.
template <typename T>
std::unique_ptr<T> create_instance_unique(const std::string &alias);
34 | |
35 | //////////////////////////////////////////////////////////////////////////////// |
36 | // A Minimalist Serializer for Taichi // |
37 | // (Requires C++17) // |
38 | //////////////////////////////////////////////////////////////////////////////// |
39 | |
40 | // TODO: Consider using third-party serialization libraries |
41 | // * https://github.com/USCiLab/cereal |
// Forward declaration of the polymorphic "unit" base class. Objects deriving
// from it are serialized by name (see BinarySerializer's unique_ptr support).
class Unit;

namespace type {

// remove_cvref<T>::type strips a reference and then top-level cv-qualifiers;
// a C++17 stand-in for C++20's std::remove_cvref.
template <typename T>
using remove_cvref = std::remove_cv<std::remove_reference_t<T>>;

// Shorthand for remove_cvref<T>::type.
template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

// Trait: the decayed T is Unit or derives from it.
template <typename T>
using is_unit = std::is_base_of<Unit, remove_cvref_t<T>>;

// The std::true_type / std::false_type underlying is_unit<T>.
template <typename T>
using is_unit_t = typename is_unit<T>::type;

}  // namespace type
class TextSerializer;
namespace detail {

// Number of `delim`-separated fields in the string literal `str`: the number
// of occurrences of `delim` plus one. The count is purely textual, so a
// delimiter nested inside template arguments or parentheses counts too.
template <size_t N>
constexpr size_t count_delim(const char (&str)[N], char delim) {
  size_t count = 1;
  for (const char &ch : str) {
    if (ch == delim) {
      ++count;
    }
  }
  return count;
}

// Splits a string literal into DelimN fields at compile time.
template <size_t DelimN>
struct StrDelimSplitter {
  // Returns the `delim`-separated fields of `str`, each trimmed of leading and
  // trailing whitespace. The result views point into `str` itself.
  // Stringification (#__VA_ARGS__) only inserts a space after a comma when the
  // source had whitespace there. Trimming keeps `TI_IO(a,b)` and
  // `TI_IO(a , b)` working; the old "skip exactly ', '" logic mangled those
  // keys. Fields beyond DelimN are ignored rather than written out of bounds.
  template <size_t StrsN>
  static constexpr std::array<std::string_view, DelimN> make(
      const char (&str)[StrsN],
      char delim) {
    std::array<std::string_view, DelimN> res{};
    // `StrsN - 1` because the last char is '\0'.
    const size_t len = StrsN - 1;
    size_t field_begin = 0;
    size_t ri = 0;
    for (size_t si = 0; si <= len; ++si) {
      if (si == len || str[si] == delim) {
        if (ri < DelimN) {
          res[ri] = trimmed(str, field_begin, si);
        }
        ++ri;
        field_begin = si + 1;
      }
    }
    return res;
  }

 private:
  // Whitespace as recognized by std::isspace in the "C" locale.
  static constexpr bool is_blank(char ch) {
    return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' ||
           ch == '\v' || ch == '\f';
  }

  // View of str[begin, end) without surrounding whitespace.
  static constexpr std::string_view trimmed(const char *str,
                                            size_t begin,
                                            size_t end) {
    while (begin < end && is_blank(str[begin])) {
      ++begin;
    }
    while (end > begin && is_blank(str[end - 1])) {
      --end;
    }
    return std::string_view(str + begin, end - begin);
  }
};

// Serializes the last field. No separator is requested after it.
template <typename SER, size_t N, typename T>
void serialize_kv_impl(SER &ser,
                       const std::array<std::string_view, N> &keys,
                       T &&val) {
  std::string key{keys[N - 1]};
  ser(key.c_str(), val);
}

// Serializes `head` under its key, then recurses on the remaining fields.
// TextSerializer also receives `append_comma = true`, so it emits a
// separator after every non-last field.
template <typename SER, size_t N, typename T, typename... Args>
void serialize_kv_impl(SER &ser,
                       const std::array<std::string_view, N> &keys,
                       T &&head,
                       Args &&...rest) {
  constexpr auto i = (N - 1 - sizeof...(Args));
  std::string key{keys[i]};
  if constexpr (std::is_same<SER, TextSerializer>::value) {
    ser(key.c_str(), head, true);
  } else {
    ser(key.c_str(), head);
  }
  serialize_kv_impl(ser, keys, rest...);
}

}  // namespace detail
137 | |
// Declares, without a body, the member serialization hook
//   template <typename S> void io(S &serializer) const
#define TI_IO_DECL      \
  template <typename S> \
  void io(S &serializer) const

// Defines the member serialization hook for the listed fields via TI_IO.
// It also expands L_JSON_SERDE_FIELDS for the same fields; that macro is
// presumably provided by taichi/common/json_serde.h (included above) — confirm.
#define TI_IO_DEF(...)             \
  L_JSON_SERDE_FIELDS(__VA_ARGS__) \
  template <typename S>            \
  void io(S &serializer) const {   \
    TI_IO(__VA_ARGS__);            \
  }
148 | |
// Serializes each listed field together with its name:
//   1. `#__VA_ARGS__` stringifies the argument list, and the string is split
//      on ',' at compile time (detail::count_delim and
//      detail::StrDelimSplitter).
//   2. serializer(key, field) is invoked for each field in order; this is
//      implemented by detail::serialize_kv_impl.
// A variable named `serializer` must be in scope (TI_IO_DEF provides it).
// NOTE: the split is purely textual, so a single argument that itself
// contains a comma (e.g. a template-id) would be split incorrectly.
#define TI_IO(...)                                                     \
  do {                                                                 \
    constexpr size_t kDelimN = detail::count_delim(#__VA_ARGS__, ','); \
    constexpr auto kSplitStrs =                                        \
        detail::StrDelimSplitter<kDelimN>::make(#__VA_ARGS__, ',');    \
    detail::serialize_kv_impl(serializer, kSplitStrs, __VA_ARGS__);    \
  } while (0)
161 | |
// Inside an io() body, yields a std::is_same<...> object (convertible to a
// constant bool) telling whether the in-scope `serializer` is of type T.
// This allows per-serializer branches, e.g. `if constexpr (TI_SERIALIZER_IS(X))`.
#define TI_SERIALIZER_IS(T)                                                 \
  (std::is_same<typename std::remove_reference<decltype(serializer)>::type, \
                T>())

// The binary format stores lengths and object addresses as std::size_t, and
// the check below requires that type to be 64-bit. It is skipped when
// TI_ARCH_x86 is defined (presumably 32-bit x86 builds — confirm).
#if !defined(TI_ARCH_x86)
static_assert(
    sizeof(std::size_t) == sizeof(uint64_t),
    "sizeof(std::size_t) should be 8. Try compiling with 64bit mode.");
#endif
// Customization point for types serialized by a free functor instead of a
// member io(). The primary template reports "not implemented". A
// specialization is expected to define `implemented` as std::true_type and a
// call operator invoked as IO<T, S>()(serializer, value) (see TextSerializer).
template <typename T, typename S>
struct IO {
  using implemented = std::false_type;
};

// Common base of all serializers: shared array aliases, the address table used
// to resolve raw pointers, and compile-time detection of serializable types.
class Serializer {
 public:
  template <typename T, std::size_t n>
  using TArray = T[n];
  template <typename T, std::size_t n>
  using StdTArray = std::array<T, n>;

  // Maps an object's original address to its current location, so raw
  // pointers can be resolved against previously serialized smart pointers.
  std::unordered_map<std::size_t, void *> assets;

  // Returns a mutable reference to `t`, casting away const. Serializers take
  // values by const& so that one code path serves both writing and reading.
  // Writing through the result is only valid when the referenced object is
  // not itself const.
  template <typename T,
            typename T_ = std::remove_cv_t<std::remove_reference_t<T>>>
  static T_ &get_writable(T &&t) {
    return *const_cast<T_ *>(std::addressof(t));
  }

  // has_io<T>::value: the decayed T has a member io(Serializer &) that
  // returns void.
  template <typename T>
  struct has_io {
    template <typename T_>
    static constexpr auto helper(T_ *) -> std::is_same<
        decltype((std::declval<T_>().io(std::declval<Serializer &>()))),
        void>;

    template <typename>
    static constexpr auto helper(...) -> std::false_type;

   public:
    // Spelled with std traits rather than type::remove_cvref_t. Naming the
    // namespace `type` here and then declaring the member `type` below changes
    // the meaning of `type` within this class ([basic.scope.class], ill-formed
    // NDR).
    using T__ = std::remove_cv_t<std::remove_reference_t<T>>;
    using type = decltype(helper<T__>(nullptr));
    static constexpr bool value = type::value;
  };

  // has_free_io<T>::value: IO<decayed T, Serializer>::implemented is true.
  template <typename T>
  struct has_free_io {
    template <typename T_>
    static constexpr auto helper(T_ *) ->
        typename IO<T_, Serializer>::implemented;

    template <typename>
    static constexpr auto helper(...) -> std::false_type;

   public:
    // See has_io for why the std traits are spelled out here.
    using T__ = std::remove_cv_t<std::remove_reference_t<T>>;
    using type = decltype(helper<T__>(nullptr));
    static constexpr bool value = type::value;
  };
};
221 | |
222 | inline std::vector<uint8_t> read_data_from_file(const std::string &fn) { |
223 | std::vector<uint8_t> data; |
224 | std::FILE *f = fopen(fn.c_str(), "rb" ); |
225 | if (f == nullptr) { |
226 | TI_DEBUG("Cannot open file: {}" , fn); |
227 | return std::vector<uint8_t>(); |
228 | } |
229 | if (ends_with(fn, ".zip" )) { |
230 | std::fclose(f); |
231 | // Read zip file, e.g. particles.tcb.zip |
232 | return zip::read(fn); |
233 | } else { |
234 | // Read uncompressed file, e.g. particles.tcb |
235 | assert(f != nullptr); |
236 | std::size_t length = 0; |
237 | while (true) { |
238 | size_t limit = 1 << 8; |
239 | data.resize(data.size() + limit); |
240 | void *ptr = reinterpret_cast<void *>(&data[length]); |
241 | size_t length_tmp = fread(ptr, sizeof(uint8_t), limit, f); |
242 | length += length_tmp; |
243 | if (length_tmp < limit) { |
244 | break; |
245 | } |
246 | } |
247 | std::fclose(f); |
248 | data.resize(length); |
249 | return data; |
250 | } |
251 | } |
252 | |
253 | inline void write_data_to_file(const std::string &fn, |
254 | uint8_t *data, |
255 | std::size_t size) { |
256 | std::FILE *f = fopen(fn.c_str(), "wb" ); |
257 | if (f == nullptr) { |
258 | TI_ERROR("Cannot open file [{}] for writing. (Does the directory exist?)" , |
259 | fn); |
260 | assert(f != nullptr); |
261 | } |
262 | if (ends_with(fn, ".tcb.zip" )) { |
263 | std::fclose(f); |
264 | zip::write(fn, data, size); |
265 | } else if (ends_with(fn, ".tcb" )) { |
266 | fwrite(data, sizeof(uint8_t), size, f); |
267 | std::fclose(f); |
268 | } else { |
269 | TI_ERROR("File must end with .tcb or .tcb.zip. [Filename = {}]" , fn); |
270 | } |
271 | } |
272 | |
273 | template <bool writing> |
274 | class BinarySerializer : public Serializer { |
275 | private: |
276 | template <typename T> |
277 | inline static constexpr bool is_elementary_type_v = |
278 | !has_io<T>::value && !std::is_pointer<T>::value && !std::is_enum_v<T> && |
279 | std::is_pod_v<T>; |
280 | |
281 | public: |
282 | std::vector<uint8_t> data; |
283 | uint8_t *c_data; |
284 | |
285 | std::size_t head; |
286 | std::size_t preserved; |
287 | |
288 | using Base = Serializer; |
289 | using Base::assets; |
290 | |
291 | template <bool writing_ = writing> |
292 | typename std::enable_if<!writing_, bool>::type initialize( |
293 | const std::string &fn) { |
294 | data = read_data_from_file(fn); |
295 | if (data.size() == 0) { |
296 | return false; |
297 | } |
298 | c_data = reinterpret_cast<uint8_t *>(&data[0]); |
299 | head = sizeof(std::size_t); |
300 | return true; |
301 | } |
302 | |
303 | void write_to_file(const std::string &fn) { |
304 | void *ptr = c_data; |
305 | if (!ptr) { |
306 | assert(!data.empty()); |
307 | ptr = &data[0]; |
308 | } |
309 | write_data_to_file(fn, reinterpret_cast<uint8_t *>(ptr), head); |
310 | } |
311 | |
312 | template <bool writing_ = writing> |
313 | typename std::enable_if<writing_, bool>::type initialize( |
314 | std::size_t preserved_ = std::size_t(0), |
315 | void *c_data = nullptr) { |
316 | std::size_t n = 0; |
317 | head = 0; |
318 | if (preserved_ != 0) { |
319 | TI_TRACE("preserved = {}" , preserved_); |
320 | // Preserved mode |
321 | this->preserved = preserved_; |
322 | assert(c_data != nullptr); |
323 | this->c_data = (uint8_t *)c_data; |
324 | } else { |
325 | // otherwise use a std::vector<uint8_t> |
326 | this->preserved = 0; |
327 | this->c_data = nullptr; |
328 | } |
329 | this->operator()("" , n); |
330 | return true; |
331 | } |
332 | |
333 | template <bool writing_ = writing> |
334 | typename std::enable_if<!writing_, void>::type initialize( |
335 | void *raw_data, |
336 | std::size_t preserved_ = std::size_t(0)) { |
337 | if (preserved_ != 0) { |
338 | assert(raw_data == nullptr); |
339 | data.resize(preserved_); |
340 | c_data = &data[0]; |
341 | } else { |
342 | assert(raw_data != nullptr); |
343 | c_data = reinterpret_cast<uint8_t *>(raw_data); |
344 | } |
345 | head = sizeof(std::size_t); |
346 | preserved = 0; |
347 | } |
348 | |
349 | template <bool writing_ = writing> |
350 | typename std::enable_if<!writing_, std::size_t>::type retrieve_length() { |
351 | return *reinterpret_cast<std::size_t *>(c_data); |
352 | } |
353 | |
354 | void finalize() { |
355 | if constexpr (writing) { |
356 | if (c_data) { |
357 | *reinterpret_cast<std::size_t *>(&c_data[0]) = head; |
358 | } else { |
359 | *reinterpret_cast<std::size_t *>(&data[0]) = head; |
360 | } |
361 | } else { |
362 | assert(head == retrieve_length()); |
363 | } |
364 | } |
365 | |
366 | template <typename T> |
367 | void operator()(const char *, const T &val) { |
368 | this->process(val); |
369 | } |
370 | |
371 | template <typename T> |
372 | void operator()(const T &val) { |
373 | this->process(val); |
374 | } |
375 | |
376 | private: |
377 | // std::string |
378 | void process(const std::string &val_) { |
379 | auto &val = get_writable(val_); |
380 | if (writing) { |
381 | std::vector<char> val_vector(val.begin(), val.end()); |
382 | this->process(val_vector); |
383 | } else { |
384 | std::vector<char> val_vector; |
385 | this->process(val_vector); |
386 | val = std::string(val_vector.begin(), val_vector.end()); |
387 | } |
388 | } |
389 | |
390 | // C-array |
391 | template <typename T, std::size_t n> |
392 | void process(const TArray<T, n> &val) { |
393 | if (writing) { |
394 | for (std::size_t i = 0; i < n; i++) { |
395 | this->process(val[i]); |
396 | } |
397 | } else { |
398 | // TODO: why do I have to let it write to tmp, otherwise I get Sig Fault? |
399 | // Take care of std::vector<bool> ... |
400 | using Traw = typename type::remove_cvref_t<T>; |
401 | std::vector< |
402 | std::conditional_t<std::is_same<Traw, bool>::value, uint8, Traw>> |
403 | tmp(n); |
404 | for (std::size_t i = 0; i < n; i++) { |
405 | this->process(tmp[i]); |
406 | } |
407 | std::memcpy(const_cast<typename std::remove_cv<T>::type *>(val), &tmp[0], |
408 | sizeof(tmp[0]) * tmp.size()); |
409 | } |
410 | } |
411 | |
412 | // Elementary data types |
413 | template <typename T> |
414 | typename std::enable_if_t<is_elementary_type_v<T>, void> process( |
415 | const T &val) { |
416 | static_assert(!std::is_reference<T>::value, "T cannot be reference" ); |
417 | static_assert(!std::is_const<T>::value, "T cannot be const" ); |
418 | static_assert(!std::is_volatile<T>::value, "T cannot be volatile" ); |
419 | static_assert(!std::is_pointer<T>::value, "T cannot be pointer" ); |
420 | if (writing) { |
421 | std::size_t new_size = head + sizeof(T); |
422 | if (c_data) { |
423 | if (new_size > preserved) { |
424 | TI_CRITICAL("Preserved Buffer (size {}) Overflow." , preserved); |
425 | } |
426 | //*reinterpret_cast<typename type::remove_cvref_t<T> *>(&c_data[head]) = |
427 | // val; |
428 | std::memcpy(&c_data[head], &val, sizeof(T)); |
429 | } else { |
430 | data.resize(new_size); |
431 | //*reinterpret_cast<typename type::remove_cvref_t<T> *>(&data[head]) = |
432 | // val; |
433 | std::memcpy(&data[head], &val, sizeof(T)); |
434 | } |
435 | } else { |
436 | // get_writable(val) = |
437 | // *reinterpret_cast<typename std::remove_reference<T>::type *>( |
438 | // &c_data[head]); |
439 | std::memcpy(&get_writable(val), &c_data[head], sizeof(T)); |
440 | } |
441 | head += sizeof(T); |
442 | } |
443 | |
444 | template <typename T> |
445 | std::enable_if_t<has_io<T>::value, void> process(const T &val) { |
446 | val.io(*this); |
447 | } |
448 | |
449 | // Unique Pointers to non-taichi-unit Types |
450 | template <typename T> |
451 | typename std::enable_if<!type::is_unit<T>::value, void>::type process( |
452 | const std::unique_ptr<T> &val_) { |
453 | auto &val = get_writable(val_); |
454 | if (writing) { |
455 | this->process(ptr_to_int(val.get())); |
456 | if (val.get() != nullptr) { |
457 | this->process(*val); |
458 | // Just for checking future raw pointers |
459 | assets.insert(std::make_pair(ptr_to_int(val.get()), val.get())); |
460 | } |
461 | } else { |
462 | std::size_t original_addr; |
463 | this->process(original_addr); |
464 | if (original_addr != 0) { |
465 | val = std::make_unique<T>(); |
466 | assets.insert(std::make_pair(original_addr, val.get())); |
467 | this->process(*val); |
468 | } |
469 | } |
470 | } |
471 | |
472 | template <typename T> |
473 | std::size_t ptr_to_int(T *t) { |
474 | return reinterpret_cast<std::size_t>(t); |
475 | } |
476 | |
477 | // Unique Pointers to taichi-unit Types |
478 | template <typename T> |
479 | typename std::enable_if<type::is_unit<T>::value, void>::type process( |
480 | const std::unique_ptr<T> &val_) { |
481 | auto &val = get_writable(val_); |
482 | if (writing) { |
483 | this->process(val->get_name()); |
484 | this->process(ptr_to_int(val.get())); |
485 | if (val.get() != nullptr) { |
486 | val->binary_io(nullptr, *this); |
487 | // Just for checking future raw pointers |
488 | assets.insert(std::make_pair(ptr_to_int(val.get()), val.get())); |
489 | } |
490 | } else { |
491 | std::string name; |
492 | std::size_t original_addr; |
493 | this->process(name); |
494 | this->process(original_addr); |
495 | if (original_addr != 0) { |
496 | val = create_instance_unique<T>(name); |
497 | assets.insert(std::make_pair(original_addr, val.get())); |
498 | val->binary_io(nullptr, *this); |
499 | } |
500 | } |
501 | } |
502 | |
503 | // Raw pointers (no ownership) |
504 | template <typename T> |
505 | typename std::enable_if<std::is_pointer<T>::value, void>::type process( |
506 | const T &val_) { |
507 | auto &val = get_writable(val_); |
508 | if (writing) { |
509 | this->process(ptr_to_int(val)); |
510 | if (val != nullptr) { |
511 | TI_ASSERT_INFO(assets.find(ptr_to_int(val)) != assets.end(), |
512 | "Cannot find the address with a smart pointer pointing " |
513 | "to. Make sure the smart pointer is serialized before " |
514 | "the raw pointer." ); |
515 | } |
516 | } else { |
517 | std::size_t val_ptr = 0; |
518 | this->process(val_ptr); |
519 | if (val_ptr != 0) { |
520 | TI_ASSERT(assets.find(val_ptr) != assets.end()); |
521 | val = reinterpret_cast<typename std::remove_pointer<T>::type *>( |
522 | assets[val_ptr]); |
523 | } |
524 | } |
525 | } |
526 | |
527 | // enum class |
528 | template <typename T> |
529 | typename std::enable_if<std::is_enum_v<T>, void>::type process(const T &val) { |
530 | using UT = std::underlying_type_t<T>; |
531 | // https://stackoverflow.com/a/62688905/12003165 |
532 | if constexpr (writing) { |
533 | this->process(static_cast<UT>(val)); |
534 | } else { |
535 | auto &wval = get_writable(val); |
536 | UT &underlying_wval = reinterpret_cast<UT &>(wval); |
537 | this->process(underlying_wval); |
538 | } |
539 | } |
540 | |
541 | // std::vector |
542 | template <typename T> |
543 | void process(const std::vector<T> &val_) { |
544 | auto &val = get_writable(val_); |
545 | if (writing) { |
546 | this->process(val.size()); |
547 | } else { |
548 | std::size_t n = 0; |
549 | this->process(n); |
550 | val.resize(n); |
551 | } |
552 | for (std::size_t i = 0; i < val.size(); i++) { |
553 | this->process(val[i]); |
554 | } |
555 | } |
556 | |
557 | // std::pair |
558 | template <typename T, typename G> |
559 | void process(const std::pair<T, G> &val) { |
560 | this->process(val.first); |
561 | this->process(val.second); |
562 | } |
563 | |
564 | // std::map |
565 | template <typename K, typename V> |
566 | void process(const std::map<K, V> &val) { |
567 | handle_associative_container(val); |
568 | } |
569 | |
570 | // std::unordered_map |
571 | template <typename K, typename V> |
572 | void process(const std::unordered_map<K, V> &val) { |
573 | handle_associative_container(val); |
574 | } |
575 | |
576 | // std::optional |
577 | template <typename T> |
578 | void process(const std::optional<T> &val) { |
579 | if constexpr (writing) { |
580 | this->process(val.has_value()); |
581 | if (val.has_value()) { |
582 | this->process(val.value()); |
583 | } |
584 | } else { |
585 | bool has_value{false}; |
586 | this->process(has_value); |
587 | auto &wval = get_writable(val); |
588 | if (!has_value) { |
589 | wval.reset(); |
590 | } else { |
591 | T new_val; |
592 | this->process(new_val); |
593 | wval = std::move(new_val); |
594 | } |
595 | } |
596 | } |
597 | |
598 | template <typename M> |
599 | void handle_associative_container(const M &val) { |
600 | if constexpr (writing) { |
601 | this->process(val.size()); |
602 | for (auto &iter : val) { |
603 | auto first = iter.first; |
604 | this->process(first); |
605 | this->process(iter.second); |
606 | } |
607 | } else { |
608 | auto &wval = get_writable(val); |
609 | wval.clear(); |
610 | std::size_t n = 0; |
611 | this->process(n); |
612 | for (std::size_t i = 0; i < n; i++) { |
613 | typename M::value_type record; |
614 | this->process(record); |
615 | wval.insert(std::move(record)); |
616 | } |
617 | } |
618 | } |
619 | }; |
620 | |
621 | using BinaryOutputSerializer = BinarySerializer<true>; |
622 | using BinaryInputSerializer = BinarySerializer<false>; |
623 | |
624 | // Serialize to JSON format |
625 | class TextSerializer : public Serializer { |
626 | public: |
627 | std::string data; |
628 | void print() const { |
629 | std::cout << data << std::endl; |
630 | } |
631 | |
632 | void write_to_file(const std::string &file_name) { |
633 | std::ofstream fs(file_name); |
634 | fs << data; |
635 | fs.close(); |
636 | } |
637 | |
638 | private: |
639 | int indent_; |
640 | static constexpr int indent_width = 2; |
641 | bool first_line_; |
642 | |
643 | template <typename T> |
644 | inline static constexpr bool is_elementary_type_v = |
645 | !has_io<T>::value && !has_free_io<T>::value && !std::is_enum_v<T> && |
646 | std::is_pod_v<T>; |
647 | |
648 | public: |
649 | TextSerializer() { |
650 | indent_ = 0; |
651 | first_line_ = false; |
652 | } |
653 | |
654 | template <typename T> |
655 | static std::string serialize(const char *key, const T &t) { |
656 | TextSerializer ser; |
657 | ser(key, t); |
658 | return ser.data; |
659 | } |
660 | |
661 | template <typename T> |
662 | void operator()(const char *key, const T &t, bool append_comma = false) { |
663 | add_key(key); |
664 | process(t); |
665 | if (append_comma) { |
666 | add_raw("," ); |
667 | } |
668 | } |
669 | |
670 | // Entry to make an AOT json file |
671 | template <typename T> |
672 | void serialize_to_json(const char *key, const T &t) { |
673 | add_raw("{" ); |
674 | (*this)(key, t); |
675 | add_raw("}" ); |
676 | } |
677 | |
678 | private: |
679 | void process(const std::string &val) { |
680 | add_raw("\"" + val + "\"" ); |
681 | } |
682 | |
683 | template <typename T, std::size_t n> |
684 | using is_compact = |
685 | typename std::integral_constant<bool, |
686 | std::is_arithmetic<T>::value && (n < 7)>; |
687 | |
688 | // C-array |
689 | template <typename T, std::size_t n> |
690 | std::enable_if_t<is_compact<T, n>::value, void> process( |
691 | const TArray<T, n> &val) { |
692 | std::stringstream ss; |
693 | ss << "{" ; |
694 | for (std::size_t i = 0; i < n; i++) { |
695 | ss << val[i]; |
696 | if (i != n - 1) { |
697 | ss << ", " ; |
698 | } |
699 | } |
700 | ss << "}" ; |
701 | add_raw(ss.str()); |
702 | } |
703 | |
704 | // C-array |
705 | template <typename T, std::size_t n> |
706 | std::enable_if_t<!is_compact<T, n>::value, void> process( |
707 | const TArray<T, n> &val) { |
708 | add_raw("{" ); |
709 | indent_++; |
710 | for (std::size_t i = 0; i < n; i++) { |
711 | add_key(std::to_string(i).c_str()); |
712 | process(val[i]); |
713 | if (i != n - 1) { |
714 | add_raw("," ); |
715 | } |
716 | } |
717 | indent_--; |
718 | add_raw("}" ); |
719 | } |
720 | |
721 | // std::array |
722 | template <typename T, std::size_t n> |
723 | std::enable_if_t<is_compact<T, n>::value, void> process( |
724 | const StdTArray<T, n> &val) { |
725 | std::stringstream ss; |
726 | ss << "{" ; |
727 | for (std::size_t i = 0; i < n; i++) { |
728 | ss << val[i]; |
729 | if (i != n - 1) { |
730 | ss << ", " ; |
731 | } |
732 | } |
733 | ss << "}" ; |
734 | add_raw(ss.str()); |
735 | } |
736 | |
737 | // std::array |
738 | template <typename T, std::size_t n> |
739 | std::enable_if_t<!is_compact<T, n>::value, void> process( |
740 | const StdTArray<T, n> &val) { |
741 | add_raw("{" ); |
742 | indent_++; |
743 | for (std::size_t i = 0; i < n; i++) { |
744 | add_key(std::to_string(i).c_str()); |
745 | process(val[i]); |
746 | if (i != n - 1) { |
747 | add_raw("," ); |
748 | } |
749 | } |
750 | indent_--; |
751 | add_raw("}" ); |
752 | } |
753 | |
754 | // Elementary data types |
755 | template <typename T> |
756 | std::enable_if_t<is_elementary_type_v<T>, void> process(const T &val) { |
757 | static_assert(!has_io<T>::value, "" ); |
758 | std::stringstream ss; |
759 | ss << std::boolalpha << val; |
760 | add_raw(ss.str()); |
761 | } |
762 | |
763 | template <typename T> |
764 | std::enable_if_t<has_io<T>::value, void> process(const T &val) { |
765 | add_raw("{" ); |
766 | indent_++; |
767 | val.io(*this); |
768 | indent_--; |
769 | add_raw("}" ); |
770 | } |
771 | |
772 | template <typename T> |
773 | std::enable_if_t<has_free_io<T>::value, void> process(const T &val) { |
774 | add_raw("{" ); |
775 | indent_++; |
776 | IO<typename type::remove_cvref_t<T>, decltype(*this)>()(*this, val); |
777 | indent_--; |
778 | add_raw("}" ); |
779 | } |
780 | |
781 | template <typename T> |
782 | std::enable_if_t<std::is_enum_v<T>, void> process(const T &val) { |
783 | using UT = std::underlying_type_t<T>; |
784 | process(static_cast<UT>(val)); |
785 | } |
786 | |
787 | template <typename T> |
788 | void process(const std::vector<T> &val) { |
789 | add_raw("[" ); |
790 | indent_++; |
791 | for (std::size_t i = 0; i < val.size(); i++) { |
792 | process(val[i]); |
793 | if (i < val.size() - 1) { |
794 | add_raw("," ); |
795 | } |
796 | } |
797 | indent_--; |
798 | add_raw("]" ); |
799 | } |
800 | |
801 | template <typename T, typename G> |
802 | void process(const std::pair<T, G> &val) { |
803 | add_raw("[" ); |
804 | indent_++; |
805 | process("first" , val.first); |
806 | add_raw(", " ); |
807 | process("second" , val.second); |
808 | indent_--; |
809 | add_raw("]" ); |
810 | } |
811 | |
812 | // std::map |
813 | template <typename K, typename V> |
814 | void process(const std::map<K, V> &val) { |
815 | handle_associative_container(val); |
816 | } |
817 | |
818 | // std::unordered_map |
819 | template <typename K, typename V> |
820 | void process(const std::unordered_map<K, V> &val) { |
821 | handle_associative_container(val); |
822 | } |
823 | |
824 | // std::optional |
825 | template <typename T> |
826 | void process(const std::optional<T> &val) { |
827 | add_raw("{" ); |
828 | indent_++; |
829 | add_key("has_value" ); |
830 | process(val.has_value()); |
831 | if (val.has_value()) { |
832 | add_raw("," ); |
833 | add_key("value" ); |
834 | process(val.value()); |
835 | } |
836 | indent_--; |
837 | add_raw("}" ); |
838 | } |
839 | |
840 | template <typename M> |
841 | void handle_associative_container(const M &val) { |
842 | add_raw("{" ); |
843 | indent_++; |
844 | for (auto iter = val.begin(); iter != val.end(); iter++) { |
845 | auto first = iter->first; |
846 | bool is_string = typeid(first) == typeid(std::string); |
847 | // Non-string keys must be wrapped by quotes. |
848 | if (!is_string) { |
849 | add_raw("\"" ); |
850 | } |
851 | process(first); |
852 | if (!is_string) { |
853 | add_raw("\"" ); |
854 | } |
855 | add_raw(": " ); |
856 | process(iter->second); |
857 | if (std::next(iter) != val.end()) { |
858 | add_raw("," ); |
859 | } |
860 | } |
861 | indent_--; |
862 | add_raw("}" ); |
863 | } |
864 | |
865 | void add_raw(const std::string &str) { |
866 | data += str; |
867 | } |
868 | |
869 | void add_key(const std::string &key) { |
870 | if (first_line_) { |
871 | first_line_ = false; |
872 | } else { |
873 | data += "\n" ; |
874 | } |
875 | data += std::string(indent_width * indent_, ' ') + "\"" + key + "\"" ; |
876 | |
877 | add_raw(": " ); |
878 | } |
879 | }; |
880 | |
881 | template <typename T> |
882 | typename std::enable_if<Serializer::has_io<T>::value, std::ostream &>::type |
883 | operator<<(std::ostream &os, const T &t) { |
884 | os << TextSerializer::serialize("value" , t); |
885 | return os; |
886 | } |
887 | |
888 | // Returns true if deserialization succeeded. |
889 | template <typename T> |
890 | bool read_from_binary(T &t, |
891 | const void *bin, |
892 | std::size_t len, |
893 | bool match_all = true) { |
894 | BinaryInputSerializer reader; |
895 | reader.initialize(const_cast<void *>(bin)); |
896 | if (len != reader.retrieve_length()) { |
897 | return false; |
898 | } |
899 | reader(t); |
900 | auto head = reader.head; |
901 | return match_all ? head == len : head <= len; |
902 | } |
903 | |
904 | template <typename T> |
905 | bool read_from_binary_file(T &t, const std::string &file_name) { |
906 | BinaryInputSerializer reader; |
907 | if (!reader.initialize(file_name)) { |
908 | return false; |
909 | } |
910 | reader(t); |
911 | reader.finalize(); |
912 | return true; |
913 | } |
914 | |
915 | template <typename T> |
916 | void write_to_binary_file(const T &t, const std::string &file_name) { |
917 | BinaryOutputSerializer writer; |
918 | writer.initialize(); |
919 | writer(t); |
920 | writer.finalize(); |
921 | writer.write_to_file(file_name); |
922 | } |
923 | |
924 | // Compile-Time Tests |
925 | static_assert(std::is_same<decltype(Serializer::get_writable( |
926 | std::declval<const std::vector<int> &>())), |
927 | std::vector<int> &>(), |
928 | "" ); |
929 | |
930 | static_assert( |
931 | std::is_same< |
932 | decltype(Serializer::get_writable( |
933 | std::declval<const std::vector<std::unique_ptr<int>> &>())), |
934 | std::vector<std::unique_ptr<int>> &>(), |
935 | "" ); |
936 | |
937 | #ifdef TI_INCLUDED |
938 | } // namespace taichi |
939 | #endif |
940 | |