1 | /******************************************************************************* |
2 | Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. |
3 | The use of this software is governed by the LICENSE file. |
4 | *******************************************************************************/ |
5 | |
6 | #pragma once |
7 | |
8 | #include <map> |
9 | #include <string> |
10 | #include <cstdio> |
11 | #include <iostream> |
12 | #include <fstream> |
13 | #include <vector> |
14 | #include <sstream> |
15 | #include <typeinfo> |
16 | |
17 | #include "taichi/common/core.h" |
18 | #include "taichi/math/math.h" |
19 | |
20 | namespace taichi { |
21 | |
// Loads one setting into a member of `this`. Expands to
//   this->name = config.get("name", default_val)
// so a variable named `config` must be in scope at the expansion site and
// `this` must have a member called `name`. The key is the stringified member
// name (`#name`). When `config` is a Dict, `default_val` is returned by
// Dict::get if the key is absent.
#define TI_LOAD_CONFIG(name, default_val) \
  this->name = config.get(#name, default_val)
26 | |
27 | class Dict { |
28 | private: |
29 | std::map<std::string, std::string> data_; |
30 | |
31 | public: |
32 | TI_IO_DEF(data_); |
33 | |
34 | Dict() = default; |
35 | |
36 | template <typename T> |
37 | Dict(const std::string &key, const T &value) { |
38 | this->set(key, value); |
39 | } |
40 | |
41 | std::vector<std::string> get_keys() const { |
42 | std::vector<std::string> keys; |
43 | for (auto it = data_.begin(); it != data_.end(); ++it) { |
44 | keys.push_back(it->first); |
45 | } |
46 | return keys; |
47 | } |
48 | |
49 | void clear() { |
50 | data_.clear(); |
51 | } |
52 | |
53 | template <typename V> |
54 | typename std::enable_if_t<(!type::is_VectorND<V>() && |
55 | !std::is_reference<V>::value && |
56 | !std::is_pointer<V>::value), |
57 | V> |
58 | get(std::string key) const; |
59 | |
60 | static bool is_string_integral(const std::string &str) { |
61 | // TODO: make it correct |
62 | if (str.find('.') != std::string::npos) { |
63 | return false; |
64 | } |
65 | if (str.find('e') != std::string::npos) { |
66 | return false; |
67 | } |
68 | if (str.find('E') != std::string::npos) { |
69 | return false; |
70 | } |
71 | return true; |
72 | } |
73 | |
74 | void check_string_integral(const std::string &str) const { |
75 | if (!is_string_integral(str)) { |
76 | TI_ERROR( |
77 | "Getting integral value out of non-integral string '{}' is not " |
78 | "allowed." , |
79 | str); |
80 | } |
81 | } |
82 | |
83 | void check_value_integral(const std::string &key) const { |
84 | auto str = get_string(key); |
85 | check_string_integral(str); |
86 | } |
87 | |
88 | template < |
89 | typename V, |
90 | typename std::enable_if<(type::is_VectorND<V>()), V>::type * = nullptr> |
91 | V get(std::string key) const { |
92 | constexpr int N = V::dim; |
93 | using T = typename V::ScalarType; |
94 | |
95 | std::string str = this->get_string(key); |
96 | std::string temp; |
97 | if (str[0] == '(') { |
98 | temp = "(" ; |
99 | } else if (str[0] == '[') { |
100 | temp = "[" ; |
101 | } |
102 | if (std::is_integral<T>()) { |
103 | check_string_integral(str); |
104 | } |
105 | for (int i = 0; i < N; i++) { |
106 | std::string placeholder; |
107 | if (std::is_same<T, float32>()) { |
108 | placeholder = "%f" ; |
109 | } else if (std::is_same<T, float64>()) { |
110 | placeholder = "%lf" ; |
111 | } else if (std::is_same<T, int32>()) { |
112 | placeholder = "%d" ; |
113 | } else if (std::is_same<T, uint32>()) { |
114 | placeholder = "%u" ; |
115 | } else if (std::is_same<T, int64>()) { |
116 | #ifdef WIN32 |
117 | placeholder = "%I64d" ; |
118 | #else |
119 | placeholder = "%lld" ; |
120 | #endif |
121 | } else if (std::is_same<T, uint64>()) { |
122 | #ifdef WIN32 |
123 | placeholder = "%I64u" ; |
124 | #else |
125 | placeholder = "%llu" ; |
126 | #endif |
127 | } else { |
128 | assert(false); |
129 | } |
130 | temp += placeholder; |
131 | if (i != N - 1) { |
132 | temp += "," ; |
133 | } |
134 | } |
135 | if (str[0] == '(') { |
136 | temp += ")" ; |
137 | } else if (str[0] == '[') { |
138 | temp += "]" ; |
139 | } |
140 | VectorND<N, T> ret; |
141 | if (N == 1) { |
142 | sscanf(str.c_str(), temp.c_str(), &ret[0]); |
143 | } else if (N == 2) { |
144 | sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1]); |
145 | } else if (N == 3) { |
146 | sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1], &ret[2]); |
147 | } else if (N == 4) { |
148 | sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1], &ret[2], &ret[3]); |
149 | } |
150 | return ret; |
151 | } |
152 | |
153 | std::string get(std::string key, const char *default_val) const; |
154 | |
155 | template <typename T> |
156 | T get(std::string key, const T &default_val) const; |
157 | |
158 | bool has_key(std::string key) const { |
159 | return data_.find(key) != data_.end(); |
160 | } |
161 | |
162 | std::vector<std::string> get_string_arr(std::string key) const { |
163 | std::string str = get_string(key); |
164 | std::vector<std::string> strs = split_string(str, "," ); |
165 | for (auto &s : strs) { |
166 | s = trim_string(s); |
167 | } |
168 | return strs; |
169 | } |
170 | |
171 | template <typename T> |
172 | T *get_ptr(std::string key) const { |
173 | std::string val = get_string(key); |
174 | std::stringstream ss(val); |
175 | std::string t; |
176 | int64 ptr_ll; |
177 | std::getline(ss, t, '\t'); |
178 | ss >> ptr_ll; |
179 | TI_ASSERT_INFO(t == typeid(T).name(), |
180 | "Pointer type mismatch: " + t + " and " + typeid(T).name()); |
181 | return reinterpret_cast<T *>(ptr_ll); |
182 | } |
183 | |
184 | template <typename T> |
185 | std::enable_if_t<std::is_pointer<T>::value, std::remove_pointer_t<T>> get( |
186 | std::string key) const { |
187 | return get_ptr<std::remove_pointer_t<T>>(key); |
188 | } |
189 | |
190 | template <typename T> |
191 | std::enable_if_t<std::is_reference<T>::value, std::remove_reference_t<T>> |
192 | &get(std::string key) const { |
193 | return *get_ptr<std::remove_reference_t<T>>(key); |
194 | } |
195 | |
196 | template <typename T> |
197 | T *get_ptr(std::string key, T *default_value) const { |
198 | if (has_key(key)) { |
199 | return get_ptr<T>(key); |
200 | } else { |
201 | return default_value; |
202 | } |
203 | } |
204 | |
205 | template <typename T> |
206 | Dict &set(std::string name, T val) { |
207 | std::stringstream ss; |
208 | ss << val; |
209 | data_[name] = ss.str(); |
210 | return *this; |
211 | } |
212 | |
213 | Dict &set(std::string name, const char *val) { |
214 | std::stringstream ss; |
215 | ss << val; |
216 | data_[name] = ss.str(); |
217 | return *this; |
218 | } |
219 | |
220 | Dict &set(std::string name, const Vector2 &val) { |
221 | std::stringstream ss; |
222 | ss << "(" << val.x << "," << val.y << ")" ; |
223 | data_[name] = ss.str(); |
224 | return *this; |
225 | } |
226 | |
227 | Dict &set(std::string name, const Vector3 &val) { |
228 | std::stringstream ss; |
229 | ss << "(" << val.x << "," << val.y << "," << val.z << ")" ; |
230 | data_[name] = ss.str(); |
231 | return *this; |
232 | } |
233 | |
234 | Dict &set(std::string name, const Vector4 &val) { |
235 | std::stringstream ss; |
236 | ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")" ; |
237 | data_[name] = ss.str(); |
238 | return *this; |
239 | } |
240 | |
241 | Dict &set(std::string name, const Vector2i &val) { |
242 | std::stringstream ss; |
243 | ss << "(" << val.x << "," << val.y << ")" ; |
244 | data_[name] = ss.str(); |
245 | return *this; |
246 | } |
247 | |
248 | Dict &set(std::string name, const Vector3i &val) { |
249 | std::stringstream ss; |
250 | ss << "(" << val.x << "," << val.y << "," << val.z << ")" ; |
251 | data_[name] = ss.str(); |
252 | return *this; |
253 | } |
254 | |
255 | Dict &set(std::string name, const Vector4i &val) { |
256 | std::stringstream ss; |
257 | ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")" ; |
258 | data_[name] = ss.str(); |
259 | return *this; |
260 | } |
261 | |
262 | template <typename T> |
263 | static std::string get_ptr_string(T *ptr) { |
264 | std::stringstream ss; |
265 | ss << typeid(T).name() << "\t" << reinterpret_cast<uint64>(ptr); |
266 | return ss.str(); |
267 | } |
268 | |
269 | template <typename T> |
270 | Dict &set(std::string name, T *const ptr) { |
271 | data_[name] = get_ptr_string(ptr); |
272 | return *this; |
273 | } |
274 | |
275 | std::string get_string(std::string key) const { |
276 | if (data_.find(key) == data_.end()) { |
277 | TI_ERROR("No key named '{}' found." , key); |
278 | } |
279 | return data_.find(key)->second; |
280 | } |
281 | |
282 | template <typename T> |
283 | Dict &operator()(const std::string &key, const T &value) { |
284 | this->set(key, value); |
285 | return *this; |
286 | } |
287 | }; |
288 | |
289 | template <> |
290 | inline std::string Dict::get<std::string>(std::string key) const { |
291 | return get_string(key); |
292 | } |
293 | |
294 | template <typename T> |
295 | inline T Dict::get(std::string key, const T &default_val) const { |
296 | if (data_.find(key) == data_.end()) { |
297 | return default_val; |
298 | } else |
299 | return get<T>(key); |
300 | } |
301 | |
302 | inline std::string Dict::get(std::string key, const char *default_val) const { |
303 | if (data_.find(key) == data_.end()) { |
304 | return default_val; |
305 | } else |
306 | return get<std::string>(key); |
307 | } |
308 | |
309 | template <> |
310 | inline float32 Dict::get<float32>(std::string key) const { |
311 | return (float32)std::atof(get_string(key).c_str()); |
312 | } |
313 | |
314 | template <> |
315 | inline float64 Dict::get<float64>(std::string key) const { |
316 | return (float64)std::atof(get_string(key).c_str()); |
317 | } |
318 | |
319 | template <> |
320 | inline int32 Dict::get<int32>(std::string key) const { |
321 | check_value_integral(key); |
322 | return std::atoi(get_string(key).c_str()); |
323 | } |
324 | |
325 | template <> |
326 | inline uint32 Dict::get<uint32>(std::string key) const { |
327 | check_value_integral(key); |
328 | return uint32(std::atoll(get_string(key).c_str())); |
329 | } |
330 | |
331 | template <> |
332 | inline int64 Dict::get<int64>(std::string key) const { |
333 | check_value_integral(key); |
334 | return std::atoll(get_string(key).c_str()); |
335 | } |
336 | |
337 | template <> |
338 | inline uint64 Dict::get<uint64>(std::string key) const { |
339 | check_value_integral(key); |
340 | return std::stoull(get_string(key)); |
341 | } |
342 | |
343 | template <> |
344 | inline bool Dict::get<bool>(std::string key) const { |
345 | std::string s = get_string(key); |
346 | static std::map<std::string, bool> dict{ |
347 | {"true" , true}, {"True" , true}, {"t" , true}, {"1" , true}, |
348 | {"false" , false}, {"False" , false}, {"f" , false}, {"0" , false}, |
349 | }; |
350 | TI_ASSERT_INFO(dict.find(s) != dict.end(), |
351 | "Unknown identifier for bool: " + s); |
352 | return dict[s]; |
353 | } |
354 | |
// `Config` is an alias of `Dict`; the two names denote the same type.
using Config = Dict;
356 | |
357 | } // namespace taichi |
358 | |