1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "../../support/str_escape.h"
#include "../utils.h"
25 | |
26 | namespace tvm { |
27 | namespace meta_schedule { |
28 | |
29 | void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { |
30 | if (!json_obj.defined()) { |
31 | os << "null" ; |
32 | } else if (const auto* int_imm = json_obj.as<IntImmNode>()) { |
33 | if (int_imm->dtype == DataType::Bool()) { |
34 | if (int_imm->value) { |
35 | os << "true" ; |
36 | } else { |
37 | os << "false" ; |
38 | } |
39 | } else { |
40 | os << int_imm->value; |
41 | } |
42 | } else if (const auto* float_imm = json_obj.as<FloatImmNode>()) { |
43 | os << std::setprecision(20) << float_imm->value; |
44 | } else if (const auto* str = json_obj.as<runtime::StringObj>()) { |
45 | os << '"' << support::StrEscape(str->data, str->size) << '"'; |
46 | } else if (const auto* array = json_obj.as<runtime::ArrayNode>()) { |
47 | os << "[" ; |
48 | int n = array->size(); |
49 | for (int i = 0; i < n; ++i) { |
50 | if (i != 0) { |
51 | os << "," ; |
52 | } |
53 | JSONDumps(array->at(i), os); |
54 | } |
55 | os << "]" ; |
56 | } else if (const auto* dict = json_obj.as<runtime::MapNode>()) { |
57 | int n = dict->size(); |
58 | std::vector<std::pair<String, ObjectRef>> key_values; |
59 | key_values.reserve(n); |
60 | for (const auto& kv : *dict) { |
61 | if (const auto* k = kv.first.as<StringObj>()) { |
62 | key_values.emplace_back(GetRef<String>(k), kv.second); |
63 | } else { |
64 | LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " |
65 | << kv.first->GetTypeKey(); |
66 | } |
67 | } |
68 | std::sort(key_values.begin(), key_values.end()); |
69 | os << "{" ; |
70 | for (int i = 0; i < n; ++i) { |
71 | const auto& kv = key_values[i]; |
72 | if (i != 0) { |
73 | os << "," ; |
74 | } |
75 | os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"'; |
76 | os << ":" ; |
77 | JSONDumps(kv.second, os); |
78 | } |
79 | os << "}" ; |
80 | } else if (json_obj->IsInstance<tir::IndexMapNode>()) { |
81 | JSONDumps(String(SaveJSON(json_obj)), os); |
82 | } else { |
83 | LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey(); |
84 | } |
85 | } |
86 | |
87 | std::string JSONDumps(ObjectRef json_obj) { |
88 | std::ostringstream os; |
89 | JSONDumps(json_obj, os); |
90 | return os.str(); |
91 | } |
92 | |
93 | class JSONTokenizer { |
94 | public: |
95 | enum class TokenType : int32_t { |
96 | kEOF = 0, // end of file |
97 | kNull = 1, // null |
98 | kTrue = 2, // true |
99 | kFalse = 3, // false |
100 | kLeftSquare = 4, // [ |
101 | kRightSquare = 5, // ] |
102 | kLeftCurly = 6, // { |
103 | kRightCurly = 7, // } |
104 | kComma = 8, // , |
105 | kColon = 9, // : |
106 | kInteger = 10, // integers |
107 | kFloat = 11, // floating point numbers |
108 | kString = 12, // string |
109 | }; |
110 | |
111 | struct Token { |
112 | TokenType type; |
113 | ObjectRef value{nullptr}; |
114 | }; |
115 | |
116 | explicit JSONTokenizer(const char* st, const char* ed) : cur_(st), end_(ed) {} |
117 | |
118 | Token Next() { |
119 | for (; cur_ != end_ && std::isspace(*cur_); ++cur_) { |
120 | } |
121 | if (cur_ == end_) return Token{TokenType::kEOF}; |
122 | if (NextLeftSquare()) return Token{TokenType::kLeftSquare}; |
123 | if (NextRightSquare()) return Token{TokenType::kRightSquare}; |
124 | if (NextLeftCurly()) return Token{TokenType::kLeftCurly}; |
125 | if (NextRightCurly()) return Token{TokenType::kRightCurly}; |
126 | if (NextComma()) return Token{TokenType::kComma}; |
127 | if (NextColon()) return Token{TokenType::kColon}; |
128 | if (NextNull()) return Token{TokenType::kNull}; |
129 | if (NextTrue()) return Token{TokenType::kTrue}; |
130 | if (NextFalse()) return Token{TokenType::kFalse}; |
131 | Token token; |
132 | if (NextString(&token)) return token; |
133 | if (NextNumber(&token)) return token; |
134 | LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_); |
135 | throw; |
136 | } |
137 | |
138 | private: |
139 | bool NextLeftSquare() { return NextLiteral('['); } |
140 | bool NextRightSquare() { return NextLiteral(']'); } |
141 | bool NextLeftCurly() { return NextLiteral('{'); } |
142 | bool NextRightCurly() { return NextLiteral('}'); } |
143 | bool NextComma() { return NextLiteral(','); } |
144 | bool NextColon() { return NextLiteral(':'); } |
145 | bool NextNull() { return NextLiteral("null" , 4); } |
146 | bool NextTrue() { return NextLiteral("true" , 4); } |
147 | bool NextFalse() { return NextLiteral("false" , 5); } |
148 | |
149 | bool NextNumber(Token* token) { |
150 | using runtime::DataType; |
151 | bool is_float = false; |
152 | const char* st = cur_; |
153 | for (; cur_ != end_; ++cur_) { |
154 | if (std::isdigit(*cur_) || *cur_ == '+' || *cur_ == '-') { |
155 | continue; |
156 | } else if (*cur_ == '.' || *cur_ == 'e' || *cur_ == 'E') { |
157 | is_float = true; |
158 | } else { |
159 | break; |
160 | } |
161 | } |
162 | if (st == cur_) { |
163 | return false; |
164 | } |
165 | std::string to_parse(st, cur_); |
166 | if (!is_float) { |
167 | try { |
168 | *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))}; |
169 | } catch (const std::invalid_argument& e) { |
170 | LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse |
171 | << ". Details: " << e.what() << ". Switching to std::stod now." ; |
172 | is_float = true; |
173 | } catch (const std::out_of_range& e) { |
174 | LOG(WARNING) << "ValueError: Out-of-range for std::stoll: " << to_parse |
175 | << ". Details: " << e.what() << ". Switching to std::stod now." ; |
176 | is_float = true; |
177 | } |
178 | } |
179 | if (is_float) { |
180 | try { |
181 | *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))}; |
182 | } catch (const std::invalid_argument& e) { |
183 | LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse |
184 | << ". Details: " << e.what(); |
185 | } catch (const std::out_of_range& e) { |
186 | LOG(INFO) << "ValueError: Out-of-range for std::stod: " << to_parse |
187 | << ". Details: " << e.what(); |
188 | } |
189 | } |
190 | return true; |
191 | } |
192 | |
193 | bool NextString(Token* token) { |
194 | if (cur_ == end_ || *cur_ != '"') return false; |
195 | ++cur_; |
196 | std::string str; |
197 | for (; cur_ != end_ && *cur_ != '\"'; ++cur_) { |
198 | if (*cur_ != '\\') { |
199 | str.push_back(*cur_); |
200 | continue; |
201 | } |
202 | ++cur_; |
203 | if (cur_ == end_) { |
204 | LOG(FATAL) << "ValueError: Unexpected end of string: \\" ; |
205 | throw; |
206 | } |
207 | switch (*cur_) { |
208 | case '\"': |
209 | str.push_back('\"'); |
210 | break; |
211 | case '\\': |
212 | str.push_back('\\'); |
213 | break; |
214 | case '/': |
215 | str.push_back('/'); |
216 | break; |
217 | case 'b': |
218 | str.push_back('\b'); |
219 | break; |
220 | case 'f': |
221 | str.push_back('\f'); |
222 | break; |
223 | case 'n': |
224 | str.push_back('\n'); |
225 | break; |
226 | case 'r': |
227 | str.push_back('\r'); |
228 | break; |
229 | case 't': |
230 | str.push_back('\t'); |
231 | break; |
232 | default: |
233 | LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_; |
234 | } |
235 | } |
236 | if (cur_ == end_) { |
237 | LOG(FATAL) << "ValueError: Unexpected end of string" ; |
238 | } |
239 | ++cur_; |
240 | *token = Token{TokenType::kString, String(str)}; |
241 | return true; |
242 | } |
243 | |
244 | bool NextLiteral(char c) { |
245 | if (cur_ != end_ && *cur_ == c) { |
246 | ++cur_; |
247 | return true; |
248 | } |
249 | return false; |
250 | } |
251 | |
252 | bool NextLiteral(const char* str, int len) { |
253 | if (cur_ + len <= end_ && std::strncmp(cur_, str, len) == 0) { |
254 | cur_ += len; |
255 | return true; |
256 | } |
257 | return false; |
258 | } |
259 | /*! \brief The current pointer */ |
260 | const char* cur_; |
261 | /*! \brief End of the string */ |
262 | const char* end_; |
263 | |
264 | friend class JSONParser; |
265 | }; |
266 | |
267 | class JSONParser { |
268 | public: |
269 | using TokenType = JSONTokenizer::TokenType; |
270 | using Token = JSONTokenizer::Token; |
271 | |
272 | explicit JSONParser(const char* st, const char* ed) : tokenizer_(st, ed) {} |
273 | |
274 | ObjectRef Get() { |
275 | Token token = tokenizer_.Next(); |
276 | if (token.type == TokenType::kEOF) { |
277 | return ObjectRef(nullptr); |
278 | } |
279 | return ParseObject(std::move(token)); |
280 | } |
281 | |
282 | private: |
283 | ObjectRef ParseObject(Token token) { |
284 | switch (token.type) { |
285 | case TokenType::kNull: |
286 | return ObjectRef(nullptr); |
287 | case TokenType::kTrue: |
288 | return Bool(true); |
289 | case TokenType::kFalse: |
290 | return Bool(false); |
291 | case TokenType::kLeftSquare: |
292 | return ParseArray(); |
293 | case TokenType::kLeftCurly: |
294 | return ParseDict(); |
295 | case TokenType::kString: |
296 | case TokenType::kInteger: |
297 | case TokenType::kFloat: |
298 | return token.value; |
299 | case TokenType::kRightSquare: |
300 | LOG(FATAL) << "ValueError: Unexpected token: ]" ; |
301 | case TokenType::kRightCurly: |
302 | LOG(FATAL) << "ValueError: Unexpected token: }" ; |
303 | case TokenType::kComma: |
304 | LOG(FATAL) << "ValueError: Unexpected token: ," ; |
305 | case TokenType::kColon: |
306 | LOG(FATAL) << "ValueError: Unexpected token: :" ; |
307 | case TokenType::kEOF: |
308 | LOG(FATAL) << "ValueError: Unexpected EOF" ; |
309 | default: |
310 | throw; |
311 | } |
312 | } |
313 | |
314 | Array<ObjectRef> ParseArray() { |
315 | bool is_first = true; |
316 | Array<ObjectRef> results; |
317 | for (;;) { |
318 | Token token; |
319 | if (is_first) { |
320 | is_first = false; |
321 | token = Token{TokenType::kComma}; |
322 | } else { |
323 | token = tokenizer_.Next(); |
324 | } |
325 | // Three cases overall: |
326 | // - Case 1. 1 token: "]" |
327 | // - Case 2. 2 tokens: ",", "]" |
328 | // - Case 3. 2 tokens: ",", "obj" |
329 | if (token.type == TokenType::kRightSquare) { // Case 1 |
330 | break; |
331 | } else if (token.type == TokenType::kComma) { |
332 | token = tokenizer_.Next(); |
333 | if (token.type == TokenType::kRightSquare) { // Case 2 |
334 | break; |
335 | } |
336 | // Case 3 |
337 | results.push_back(ParseObject(std::move(token))); |
338 | continue; |
339 | } else { |
340 | LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; |
341 | } |
342 | } |
343 | return results; |
344 | } |
345 | |
346 | Map<String, ObjectRef> ParseDict() { |
347 | bool is_first = true; |
348 | Map<String, ObjectRef> results; |
349 | for (;;) { |
350 | Token token; |
351 | if (is_first) { |
352 | is_first = false; |
353 | token = Token{TokenType::kComma}; |
354 | } else { |
355 | token = tokenizer_.Next(); |
356 | } |
357 | // Three cases overall: |
358 | // - Case 1. 1 token: "}" |
359 | // - Case 2. 2 tokens: ",", "}" |
360 | // - Case 3. 2 tokens: ",", "key", ":", "value" |
361 | if (token.type == TokenType::kRightCurly) { // Case 1 |
362 | break; |
363 | } else if (token.type == TokenType::kComma) { |
364 | token = tokenizer_.Next(); |
365 | if (token.type == TokenType::kRightCurly) { // Case 2 |
366 | break; |
367 | } |
368 | // Case 3 |
369 | ObjectRef key = ParseObject(std::move(token)); |
370 | ICHECK(key->IsInstance<StringObj>()) |
371 | << "ValueError: key must be a string, but gets: " << key; |
372 | token = tokenizer_.Next(); |
373 | CHECK(token.type == TokenType::kColon) |
374 | << "ValueError: Unexpected token before: " << tokenizer_.cur_; |
375 | ObjectRef value = ParseObject(tokenizer_.Next()); |
376 | results.Set(Downcast<String>(key), value); |
377 | continue; |
378 | } else { |
379 | LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; |
380 | } |
381 | } |
382 | return results; |
383 | } |
384 | |
385 | JSONTokenizer tokenizer_; |
386 | }; |
387 | |
388 | ObjectRef JSONLoads(std::string str) { |
389 | const char* st = str.c_str(); |
390 | const char* ed = st + str.length(); |
391 | return JSONParser(st, ed).Get(); |
392 | } |
393 | |
394 | } // namespace meta_schedule |
395 | } // namespace tvm |
396 | |