1/*******************************************************************************
2 Copyright (c) The Taichi Authors (2016- ). All Rights Reserved.
3 The use of this software is governed by the LICENSE file.
4*******************************************************************************/
5
6#pragma once
7
8#include <map>
9#include <string>
10#include <cstdio>
11#include <iostream>
12#include <fstream>
13#include <vector>
14#include <sstream>
15#include <typeinfo>
16
17#include "taichi/common/core.h"
18#include "taichi/math/math.h"
19
20namespace taichi {
21
// TI_LOAD_CONFIG(name, default_val): "declare, then load" helper.
// Declare the member `name` first, then use this macro inside a member
// function to assign `this->name` from a variable named `config` that must be
// in scope at the expansion site. The lookup key is the stringized member name
// (#name), and the two-argument `get` is used, so `default_val` is the fallback
// when the key is missing -- NOTE(review): presumably `config` is a
// Dict/Config; confirm at call sites.
#define TI_LOAD_CONFIG(name, default_val) \
  this->name = config.get(#name, default_val)
26
27class Dict {
28 private:
29 std::map<std::string, std::string> data_;
30
31 public:
32 TI_IO_DEF(data_);
33
34 Dict() = default;
35
36 template <typename T>
37 Dict(const std::string &key, const T &value) {
38 this->set(key, value);
39 }
40
41 std::vector<std::string> get_keys() const {
42 std::vector<std::string> keys;
43 for (auto it = data_.begin(); it != data_.end(); ++it) {
44 keys.push_back(it->first);
45 }
46 return keys;
47 }
48
49 void clear() {
50 data_.clear();
51 }
52
53 template <typename V>
54 typename std::enable_if_t<(!type::is_VectorND<V>() &&
55 !std::is_reference<V>::value &&
56 !std::is_pointer<V>::value),
57 V>
58 get(std::string key) const;
59
60 static bool is_string_integral(const std::string &str) {
61 // TODO: make it correct
62 if (str.find('.') != std::string::npos) {
63 return false;
64 }
65 if (str.find('e') != std::string::npos) {
66 return false;
67 }
68 if (str.find('E') != std::string::npos) {
69 return false;
70 }
71 return true;
72 }
73
74 void check_string_integral(const std::string &str) const {
75 if (!is_string_integral(str)) {
76 TI_ERROR(
77 "Getting integral value out of non-integral string '{}' is not "
78 "allowed.",
79 str);
80 }
81 }
82
83 void check_value_integral(const std::string &key) const {
84 auto str = get_string(key);
85 check_string_integral(str);
86 }
87
88 template <
89 typename V,
90 typename std::enable_if<(type::is_VectorND<V>()), V>::type * = nullptr>
91 V get(std::string key) const {
92 constexpr int N = V::dim;
93 using T = typename V::ScalarType;
94
95 std::string str = this->get_string(key);
96 std::string temp;
97 if (str[0] == '(') {
98 temp = "(";
99 } else if (str[0] == '[') {
100 temp = "[";
101 }
102 if (std::is_integral<T>()) {
103 check_string_integral(str);
104 }
105 for (int i = 0; i < N; i++) {
106 std::string placeholder;
107 if (std::is_same<T, float32>()) {
108 placeholder = "%f";
109 } else if (std::is_same<T, float64>()) {
110 placeholder = "%lf";
111 } else if (std::is_same<T, int32>()) {
112 placeholder = "%d";
113 } else if (std::is_same<T, uint32>()) {
114 placeholder = "%u";
115 } else if (std::is_same<T, int64>()) {
116#ifdef WIN32
117 placeholder = "%I64d";
118#else
119 placeholder = "%lld";
120#endif
121 } else if (std::is_same<T, uint64>()) {
122#ifdef WIN32
123 placeholder = "%I64u";
124#else
125 placeholder = "%llu";
126#endif
127 } else {
128 assert(false);
129 }
130 temp += placeholder;
131 if (i != N - 1) {
132 temp += ",";
133 }
134 }
135 if (str[0] == '(') {
136 temp += ")";
137 } else if (str[0] == '[') {
138 temp += "]";
139 }
140 VectorND<N, T> ret;
141 if (N == 1) {
142 sscanf(str.c_str(), temp.c_str(), &ret[0]);
143 } else if (N == 2) {
144 sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1]);
145 } else if (N == 3) {
146 sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1], &ret[2]);
147 } else if (N == 4) {
148 sscanf(str.c_str(), temp.c_str(), &ret[0], &ret[1], &ret[2], &ret[3]);
149 }
150 return ret;
151 }
152
153 std::string get(std::string key, const char *default_val) const;
154
155 template <typename T>
156 T get(std::string key, const T &default_val) const;
157
158 bool has_key(std::string key) const {
159 return data_.find(key) != data_.end();
160 }
161
162 std::vector<std::string> get_string_arr(std::string key) const {
163 std::string str = get_string(key);
164 std::vector<std::string> strs = split_string(str, ",");
165 for (auto &s : strs) {
166 s = trim_string(s);
167 }
168 return strs;
169 }
170
171 template <typename T>
172 T *get_ptr(std::string key) const {
173 std::string val = get_string(key);
174 std::stringstream ss(val);
175 std::string t;
176 int64 ptr_ll;
177 std::getline(ss, t, '\t');
178 ss >> ptr_ll;
179 TI_ASSERT_INFO(t == typeid(T).name(),
180 "Pointer type mismatch: " + t + " and " + typeid(T).name());
181 return reinterpret_cast<T *>(ptr_ll);
182 }
183
184 template <typename T>
185 std::enable_if_t<std::is_pointer<T>::value, std::remove_pointer_t<T>> get(
186 std::string key) const {
187 return get_ptr<std::remove_pointer_t<T>>(key);
188 }
189
190 template <typename T>
191 std::enable_if_t<std::is_reference<T>::value, std::remove_reference_t<T>>
192 &get(std::string key) const {
193 return *get_ptr<std::remove_reference_t<T>>(key);
194 }
195
196 template <typename T>
197 T *get_ptr(std::string key, T *default_value) const {
198 if (has_key(key)) {
199 return get_ptr<T>(key);
200 } else {
201 return default_value;
202 }
203 }
204
205 template <typename T>
206 Dict &set(std::string name, T val) {
207 std::stringstream ss;
208 ss << val;
209 data_[name] = ss.str();
210 return *this;
211 }
212
213 Dict &set(std::string name, const char *val) {
214 std::stringstream ss;
215 ss << val;
216 data_[name] = ss.str();
217 return *this;
218 }
219
220 Dict &set(std::string name, const Vector2 &val) {
221 std::stringstream ss;
222 ss << "(" << val.x << "," << val.y << ")";
223 data_[name] = ss.str();
224 return *this;
225 }
226
227 Dict &set(std::string name, const Vector3 &val) {
228 std::stringstream ss;
229 ss << "(" << val.x << "," << val.y << "," << val.z << ")";
230 data_[name] = ss.str();
231 return *this;
232 }
233
234 Dict &set(std::string name, const Vector4 &val) {
235 std::stringstream ss;
236 ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")";
237 data_[name] = ss.str();
238 return *this;
239 }
240
241 Dict &set(std::string name, const Vector2i &val) {
242 std::stringstream ss;
243 ss << "(" << val.x << "," << val.y << ")";
244 data_[name] = ss.str();
245 return *this;
246 }
247
248 Dict &set(std::string name, const Vector3i &val) {
249 std::stringstream ss;
250 ss << "(" << val.x << "," << val.y << "," << val.z << ")";
251 data_[name] = ss.str();
252 return *this;
253 }
254
255 Dict &set(std::string name, const Vector4i &val) {
256 std::stringstream ss;
257 ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")";
258 data_[name] = ss.str();
259 return *this;
260 }
261
262 template <typename T>
263 static std::string get_ptr_string(T *ptr) {
264 std::stringstream ss;
265 ss << typeid(T).name() << "\t" << reinterpret_cast<uint64>(ptr);
266 return ss.str();
267 }
268
269 template <typename T>
270 Dict &set(std::string name, T *const ptr) {
271 data_[name] = get_ptr_string(ptr);
272 return *this;
273 }
274
275 std::string get_string(std::string key) const {
276 if (data_.find(key) == data_.end()) {
277 TI_ERROR("No key named '{}' found.", key);
278 }
279 return data_.find(key)->second;
280 }
281
282 template <typename T>
283 Dict &operator()(const std::string &key, const T &value) {
284 this->set(key, value);
285 return *this;
286 }
287};
288
289template <>
290inline std::string Dict::get<std::string>(std::string key) const {
291 return get_string(key);
292}
293
294template <typename T>
295inline T Dict::get(std::string key, const T &default_val) const {
296 if (data_.find(key) == data_.end()) {
297 return default_val;
298 } else
299 return get<T>(key);
300}
301
302inline std::string Dict::get(std::string key, const char *default_val) const {
303 if (data_.find(key) == data_.end()) {
304 return default_val;
305 } else
306 return get<std::string>(key);
307}
308
309template <>
310inline float32 Dict::get<float32>(std::string key) const {
311 return (float32)std::atof(get_string(key).c_str());
312}
313
314template <>
315inline float64 Dict::get<float64>(std::string key) const {
316 return (float64)std::atof(get_string(key).c_str());
317}
318
319template <>
320inline int32 Dict::get<int32>(std::string key) const {
321 check_value_integral(key);
322 return std::atoi(get_string(key).c_str());
323}
324
325template <>
326inline uint32 Dict::get<uint32>(std::string key) const {
327 check_value_integral(key);
328 return uint32(std::atoll(get_string(key).c_str()));
329}
330
331template <>
332inline int64 Dict::get<int64>(std::string key) const {
333 check_value_integral(key);
334 return std::atoll(get_string(key).c_str());
335}
336
337template <>
338inline uint64 Dict::get<uint64>(std::string key) const {
339 check_value_integral(key);
340 return std::stoull(get_string(key));
341}
342
343template <>
344inline bool Dict::get<bool>(std::string key) const {
345 std::string s = get_string(key);
346 static std::map<std::string, bool> dict{
347 {"true", true}, {"True", true}, {"t", true}, {"1", true},
348 {"false", false}, {"False", false}, {"f", false}, {"0", false},
349 };
350 TI_ASSERT_INFO(dict.find(s) != dict.end(),
351 "Unknown identifier for bool: " + s);
352 return dict[s];
353}
354
// `Config` is another name for `Dict`; both spellings denote the same type.
using Config = Dict;
356
357} // namespace taichi
358