1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "../../support/str_escape.h"
#include "../utils.h"
25
26namespace tvm {
27namespace meta_schedule {
28
29void JSONDumps(ObjectRef json_obj, std::ostringstream& os) {
30 if (!json_obj.defined()) {
31 os << "null";
32 } else if (const auto* int_imm = json_obj.as<IntImmNode>()) {
33 if (int_imm->dtype == DataType::Bool()) {
34 if (int_imm->value) {
35 os << "true";
36 } else {
37 os << "false";
38 }
39 } else {
40 os << int_imm->value;
41 }
42 } else if (const auto* float_imm = json_obj.as<FloatImmNode>()) {
43 os << std::setprecision(20) << float_imm->value;
44 } else if (const auto* str = json_obj.as<runtime::StringObj>()) {
45 os << '"' << support::StrEscape(str->data, str->size) << '"';
46 } else if (const auto* array = json_obj.as<runtime::ArrayNode>()) {
47 os << "[";
48 int n = array->size();
49 for (int i = 0; i < n; ++i) {
50 if (i != 0) {
51 os << ",";
52 }
53 JSONDumps(array->at(i), os);
54 }
55 os << "]";
56 } else if (const auto* dict = json_obj.as<runtime::MapNode>()) {
57 int n = dict->size();
58 std::vector<std::pair<String, ObjectRef>> key_values;
59 key_values.reserve(n);
60 for (const auto& kv : *dict) {
61 if (const auto* k = kv.first.as<StringObj>()) {
62 key_values.emplace_back(GetRef<String>(k), kv.second);
63 } else {
64 LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: "
65 << kv.first->GetTypeKey();
66 }
67 }
68 std::sort(key_values.begin(), key_values.end());
69 os << "{";
70 for (int i = 0; i < n; ++i) {
71 const auto& kv = key_values[i];
72 if (i != 0) {
73 os << ",";
74 }
75 os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"';
76 os << ":";
77 JSONDumps(kv.second, os);
78 }
79 os << "}";
80 } else if (json_obj->IsInstance<tir::IndexMapNode>()) {
81 JSONDumps(String(SaveJSON(json_obj)), os);
82 } else {
83 LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey();
84 }
85}
86
87std::string JSONDumps(ObjectRef json_obj) {
88 std::ostringstream os;
89 JSONDumps(json_obj, os);
90 return os.str();
91}
92
93class JSONTokenizer {
94 public:
95 enum class TokenType : int32_t {
96 kEOF = 0, // end of file
97 kNull = 1, // null
98 kTrue = 2, // true
99 kFalse = 3, // false
100 kLeftSquare = 4, // [
101 kRightSquare = 5, // ]
102 kLeftCurly = 6, // {
103 kRightCurly = 7, // }
104 kComma = 8, // ,
105 kColon = 9, // :
106 kInteger = 10, // integers
107 kFloat = 11, // floating point numbers
108 kString = 12, // string
109 };
110
111 struct Token {
112 TokenType type;
113 ObjectRef value{nullptr};
114 };
115
116 explicit JSONTokenizer(const char* st, const char* ed) : cur_(st), end_(ed) {}
117
118 Token Next() {
119 for (; cur_ != end_ && std::isspace(*cur_); ++cur_) {
120 }
121 if (cur_ == end_) return Token{TokenType::kEOF};
122 if (NextLeftSquare()) return Token{TokenType::kLeftSquare};
123 if (NextRightSquare()) return Token{TokenType::kRightSquare};
124 if (NextLeftCurly()) return Token{TokenType::kLeftCurly};
125 if (NextRightCurly()) return Token{TokenType::kRightCurly};
126 if (NextComma()) return Token{TokenType::kComma};
127 if (NextColon()) return Token{TokenType::kColon};
128 if (NextNull()) return Token{TokenType::kNull};
129 if (NextTrue()) return Token{TokenType::kTrue};
130 if (NextFalse()) return Token{TokenType::kFalse};
131 Token token;
132 if (NextString(&token)) return token;
133 if (NextNumber(&token)) return token;
134 LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_);
135 throw;
136 }
137
138 private:
139 bool NextLeftSquare() { return NextLiteral('['); }
140 bool NextRightSquare() { return NextLiteral(']'); }
141 bool NextLeftCurly() { return NextLiteral('{'); }
142 bool NextRightCurly() { return NextLiteral('}'); }
143 bool NextComma() { return NextLiteral(','); }
144 bool NextColon() { return NextLiteral(':'); }
145 bool NextNull() { return NextLiteral("null", 4); }
146 bool NextTrue() { return NextLiteral("true", 4); }
147 bool NextFalse() { return NextLiteral("false", 5); }
148
149 bool NextNumber(Token* token) {
150 using runtime::DataType;
151 bool is_float = false;
152 const char* st = cur_;
153 for (; cur_ != end_; ++cur_) {
154 if (std::isdigit(*cur_) || *cur_ == '+' || *cur_ == '-') {
155 continue;
156 } else if (*cur_ == '.' || *cur_ == 'e' || *cur_ == 'E') {
157 is_float = true;
158 } else {
159 break;
160 }
161 }
162 if (st == cur_) {
163 return false;
164 }
165 std::string to_parse(st, cur_);
166 if (!is_float) {
167 try {
168 *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))};
169 } catch (const std::invalid_argument& e) {
170 LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse
171 << ". Details: " << e.what() << ". Switching to std::stod now.";
172 is_float = true;
173 } catch (const std::out_of_range& e) {
174 LOG(WARNING) << "ValueError: Out-of-range for std::stoll: " << to_parse
175 << ". Details: " << e.what() << ". Switching to std::stod now.";
176 is_float = true;
177 }
178 }
179 if (is_float) {
180 try {
181 *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))};
182 } catch (const std::invalid_argument& e) {
183 LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse
184 << ". Details: " << e.what();
185 } catch (const std::out_of_range& e) {
186 LOG(INFO) << "ValueError: Out-of-range for std::stod: " << to_parse
187 << ". Details: " << e.what();
188 }
189 }
190 return true;
191 }
192
193 bool NextString(Token* token) {
194 if (cur_ == end_ || *cur_ != '"') return false;
195 ++cur_;
196 std::string str;
197 for (; cur_ != end_ && *cur_ != '\"'; ++cur_) {
198 if (*cur_ != '\\') {
199 str.push_back(*cur_);
200 continue;
201 }
202 ++cur_;
203 if (cur_ == end_) {
204 LOG(FATAL) << "ValueError: Unexpected end of string: \\";
205 throw;
206 }
207 switch (*cur_) {
208 case '\"':
209 str.push_back('\"');
210 break;
211 case '\\':
212 str.push_back('\\');
213 break;
214 case '/':
215 str.push_back('/');
216 break;
217 case 'b':
218 str.push_back('\b');
219 break;
220 case 'f':
221 str.push_back('\f');
222 break;
223 case 'n':
224 str.push_back('\n');
225 break;
226 case 'r':
227 str.push_back('\r');
228 break;
229 case 't':
230 str.push_back('\t');
231 break;
232 default:
233 LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_;
234 }
235 }
236 if (cur_ == end_) {
237 LOG(FATAL) << "ValueError: Unexpected end of string";
238 }
239 ++cur_;
240 *token = Token{TokenType::kString, String(str)};
241 return true;
242 }
243
244 bool NextLiteral(char c) {
245 if (cur_ != end_ && *cur_ == c) {
246 ++cur_;
247 return true;
248 }
249 return false;
250 }
251
252 bool NextLiteral(const char* str, int len) {
253 if (cur_ + len <= end_ && std::strncmp(cur_, str, len) == 0) {
254 cur_ += len;
255 return true;
256 }
257 return false;
258 }
259 /*! \brief The current pointer */
260 const char* cur_;
261 /*! \brief End of the string */
262 const char* end_;
263
264 friend class JSONParser;
265};
266
267class JSONParser {
268 public:
269 using TokenType = JSONTokenizer::TokenType;
270 using Token = JSONTokenizer::Token;
271
272 explicit JSONParser(const char* st, const char* ed) : tokenizer_(st, ed) {}
273
274 ObjectRef Get() {
275 Token token = tokenizer_.Next();
276 if (token.type == TokenType::kEOF) {
277 return ObjectRef(nullptr);
278 }
279 return ParseObject(std::move(token));
280 }
281
282 private:
283 ObjectRef ParseObject(Token token) {
284 switch (token.type) {
285 case TokenType::kNull:
286 return ObjectRef(nullptr);
287 case TokenType::kTrue:
288 return Bool(true);
289 case TokenType::kFalse:
290 return Bool(false);
291 case TokenType::kLeftSquare:
292 return ParseArray();
293 case TokenType::kLeftCurly:
294 return ParseDict();
295 case TokenType::kString:
296 case TokenType::kInteger:
297 case TokenType::kFloat:
298 return token.value;
299 case TokenType::kRightSquare:
300 LOG(FATAL) << "ValueError: Unexpected token: ]";
301 case TokenType::kRightCurly:
302 LOG(FATAL) << "ValueError: Unexpected token: }";
303 case TokenType::kComma:
304 LOG(FATAL) << "ValueError: Unexpected token: ,";
305 case TokenType::kColon:
306 LOG(FATAL) << "ValueError: Unexpected token: :";
307 case TokenType::kEOF:
308 LOG(FATAL) << "ValueError: Unexpected EOF";
309 default:
310 throw;
311 }
312 }
313
314 Array<ObjectRef> ParseArray() {
315 bool is_first = true;
316 Array<ObjectRef> results;
317 for (;;) {
318 Token token;
319 if (is_first) {
320 is_first = false;
321 token = Token{TokenType::kComma};
322 } else {
323 token = tokenizer_.Next();
324 }
325 // Three cases overall:
326 // - Case 1. 1 token: "]"
327 // - Case 2. 2 tokens: ",", "]"
328 // - Case 3. 2 tokens: ",", "obj"
329 if (token.type == TokenType::kRightSquare) { // Case 1
330 break;
331 } else if (token.type == TokenType::kComma) {
332 token = tokenizer_.Next();
333 if (token.type == TokenType::kRightSquare) { // Case 2
334 break;
335 }
336 // Case 3
337 results.push_back(ParseObject(std::move(token)));
338 continue;
339 } else {
340 LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
341 }
342 }
343 return results;
344 }
345
346 Map<String, ObjectRef> ParseDict() {
347 bool is_first = true;
348 Map<String, ObjectRef> results;
349 for (;;) {
350 Token token;
351 if (is_first) {
352 is_first = false;
353 token = Token{TokenType::kComma};
354 } else {
355 token = tokenizer_.Next();
356 }
357 // Three cases overall:
358 // - Case 1. 1 token: "}"
359 // - Case 2. 2 tokens: ",", "}"
360 // - Case 3. 2 tokens: ",", "key", ":", "value"
361 if (token.type == TokenType::kRightCurly) { // Case 1
362 break;
363 } else if (token.type == TokenType::kComma) {
364 token = tokenizer_.Next();
365 if (token.type == TokenType::kRightCurly) { // Case 2
366 break;
367 }
368 // Case 3
369 ObjectRef key = ParseObject(std::move(token));
370 ICHECK(key->IsInstance<StringObj>())
371 << "ValueError: key must be a string, but gets: " << key;
372 token = tokenizer_.Next();
373 CHECK(token.type == TokenType::kColon)
374 << "ValueError: Unexpected token before: " << tokenizer_.cur_;
375 ObjectRef value = ParseObject(tokenizer_.Next());
376 results.Set(Downcast<String>(key), value);
377 continue;
378 } else {
379 LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
380 }
381 }
382 return results;
383 }
384
385 JSONTokenizer tokenizer_;
386};
387
388ObjectRef JSONLoads(std::string str) {
389 const char* st = str.c_str();
390 const char* ed = st + str.length();
391 return JSONParser(st, ed).Get();
392}
393
394} // namespace meta_schedule
395} // namespace tvm
396