1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tokenizer.h
 * \brief A tokenizer for the TVM text format.
 */
24#ifndef TVM_PARSER_TOKENIZER_H_
25#define TVM_PARSER_TOKENIZER_H_
26
#include <tvm/node/serialization.h>
#include <tvm/runtime/object.h>

#include <algorithm>
#include <cctype>
#include <fstream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./token.h"
40
41namespace tvm {
42namespace parser {
43
44using namespace runtime;
45
/*!
 * \brief Strip leading whitespace from a string, in place.
 * \param s The string to trim; it is modified directly.
 */
static inline void ltrim(std::string& s) {  // NOLINT(*)
  // std::isspace has undefined behaviour for values that do not fit in `unsigned char`,
  // so the predicate takes `unsigned char` to keep non-ASCII bytes in range.
  auto first_kept =
      std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); });
  s.erase(s.begin(), first_kept);
}
50
/*!
 * \brief Strip trailing whitespace from a string, in place.
 * \param s The string to trim; it is modified directly.
 */
static inline void rtrim(std::string& s) {  // NOLINT(*)
  // std::isspace has undefined behaviour for values that do not fit in `unsigned char`,
  // so the predicate takes `unsigned char` to keep non-ASCII bytes in range.
  auto last_kept =
      std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); });
  // base() converts the reverse iterator into the forward iterator one past the kept character.
  s.erase(last_kept.base(), s.end());
}
56
/*!
 * \brief Whether `c` is an ASCII decimal digit (`0` through `9`).
 *
 * Declared `inline` because the definition lives in a header.
 */
inline bool IsDigit(char c) { return '0' <= c && c <= '9'; }
58
/*!
 * \brief Whether `c` is a space, a tab or a line feed.
 *
 * Carriage return is deliberately not included here.
 * Declared `inline` because the definition lives in a header.
 */
inline bool IsWhitespace(char c) { return c == ' ' || c == '\t' || c == '\n'; }
60
61bool IsNumeric(char c) {
62 return (IsDigit(c) || c == '.' || c == 'e' || c == '-' || c == '+' || c == 'E') &&
63 !IsWhitespace(c);
64}
65
/*!
 * \brief Whether `c` may start an identifier: an ASCII letter, `_` or `/`.
 *
 * Declared `inline` because the definition lives in a header.
 */
inline bool IsIdentLetter(char c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  return c == '_' || c == '/' || is_lower || is_upper;
}
69
70bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); }
71
72static std::unordered_map<std::string, TokenType> KEYWORD_TABLE = {
73 {"let", TokenType::kLet}, {"fn", TokenType::kFn},
74 {"def", TokenType::kDefn}, {"if", TokenType::kIf},
75 {"else", TokenType::kElse}, {"type", TokenType::kTypeDef},
76 {"match", TokenType::kMatch}, {"extern", TokenType::kExtern},
77 {"free_var", TokenType::kFreeVar}, {"ref", TokenType::kRef},
78 {"ref_read", TokenType::kRefRead}, {"ref_write", TokenType::kRefWrite}};
79
80struct Tokenizer {
81 DiagnosticContext diag_ctx;
82 const SourceName& source_name;
83
84 size_t pos;
85 int col;
86 int line;
87 char next_char;
88 String source;
89 std::vector<Token> tokens;
90
91 char Next() {
92 char c = this->source.at(this->pos);
93 if (c == '\n') {
94 this->line += 1;
95 this->col = 1;
96 } else {
97 this->col += 1;
98 }
99 pos += 1;
100 return c;
101 }
102
103 bool More() { return this->pos < this->source.size(); }
104
105 char Peek() {
106 ICHECK(pos < this->source.size());
107 return this->source.at(this->pos);
108 }
109
110 Token NewToken(TokenType token_type, ObjectRef data = ObjectRef(), int lines = 0, int cols = 1) {
111 auto span =
112 Span(this->source_name, this->line, this->line + lines, this->col, this->col + cols);
113 return Token(span, token_type, data);
114 }
115
116 Span SpanFrom(int line, int column) {
117 int end_line = this->line;
118 int end_column = this->col;
119 return Span(this->source_name, line, end_line, column, end_column);
120 }
121
122 enum CommentParserState {
123 Proceed,
124 Forward,
125 Backward,
126 };
127
128 void MatchComment(std::string* buffer) {
129 // We only invoke this after we have matched the first start
130 // token assume, we are proceeding the parse forward with
131 // nesting = 1.
132 //
133 // When we are done we should be at nesting zero and be
134 // in the stop state.
135 CommentParserState state = CommentParserState::Proceed;
136 int nesting = 1;
137
138 while (More()) {
139 switch (state) {
140 case CommentParserState::Proceed: {
141 if (Peek() == '/') {
142 state = CommentParserState::Forward;
143 } else if (Peek() == '*') {
144 state = CommentParserState::Backward;
145 }
146 buffer->operator+=(Next());
147 continue;
148 }
149 case CommentParserState::Forward: {
150 if (Peek() == '*') {
151 nesting += 1;
152 buffer->operator+=(Next());
153 }
154 state = CommentParserState::Proceed;
155 continue;
156 }
157 case CommentParserState::Backward: {
158 if (Peek() == '/') {
159 nesting -= 1;
160 if (nesting == 0) {
161 Next();
162 buffer->pop_back();
163 return;
164 }
165 }
166
167 buffer->operator+=(Next());
168 state = CommentParserState::Proceed;
169 continue;
170 }
171 }
172 }
173 }
174
175 Token ParseNumber(bool is_pos, bool is_float, std::string number) {
176 ICHECK(number.size() > 0) << "an empty string is an invalid number";
177
178 Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger);
179 size_t suffix_pos = number.rfind(is_float ? 'f' : 'i');
180 if (suffix_pos == std::string::npos) {
181 suffix_pos = number.size();
182 }
183 std::string literal_text = number.substr(0, suffix_pos);
184 std::string suffix;
185 if (suffix_pos < number.size()) {
186 suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
187 }
188 int width = 32;
189
190 if (suffix.size()) {
191 try {
192 width = std::stoi(suffix);
193 } catch (const std::invalid_argument& err) {
194 this->diag_ctx.Emit(Diagnostic::Error(token->span)
195 << "invalid numeric suffix `" << suffix << "`");
196 } catch (const std::out_of_range& err) {
197 this->diag_ctx.Emit(Diagnostic::Error(token->span)
198 << "invalid numeric suffix `" << suffix << "`");
199 }
200 }
201
202 if (is_float) {
203 double value = 0.0;
204 size_t index = 0;
205 try {
206 value = stod(literal_text, &index);
207 } catch (const std::invalid_argument& err) {
208 this->diag_ctx.Emit(Diagnostic::Error(token->span)
209 << "invalid floating point number `" << literal_text << "`");
210 } catch (const std::out_of_range& err) {
211 this->diag_ctx.Emit(Diagnostic::Error(token->span)
212 << "invalid floating point number `" << literal_text << "`");
213 }
214 if (index < literal_text.size()) {
215 this->diag_ctx.Emit(Diagnostic::Error(token->span)
216 << "invalid floating point number `" << literal_text << "`");
217 }
218 value = is_pos ? value : -value;
219 token->data = support::ValueToFloatImm(value, width);
220 if (!token->data.defined()) {
221 this->diag_ctx.Emit(Diagnostic::Error(token->span)
222 << "floating point number `" << literal_text
223 << "` unrepresentable in width " << width);
224 token->data = support::ValueToFloatImm(0.0, width);
225 }
226 } else {
227 int64_t value = 0;
228 size_t index = 0;
229 try {
230 value = std::stoll(literal_text, &index);
231 } catch (const std::invalid_argument& err) {
232 this->diag_ctx.Emit(Diagnostic::Error(token->span)
233 << "invalid integer number `" << literal_text << "`");
234 } catch (const std::out_of_range& err) {
235 this->diag_ctx.Emit(Diagnostic::Error(token->span)
236 << "invalid integer number `" << literal_text << "`");
237 }
238 if (index < literal_text.size()) {
239 this->diag_ctx.Emit(Diagnostic::Error(token->span)
240 << "invalid integer number `" << literal_text << "`");
241 }
242 value = is_pos ? value : -value;
243 token->data = support::ValueToIntImm(value, width);
244 if (!token->data.defined() && suffix.empty()) {
245 // Without any i suffix the legacy behavior was to default to int64 if out of range
246 // for int32.
247 width = 64;
248 token->data = support::ValueToIntImm(value, width);
249 }
250 if (!token->data.defined()) {
251 this->diag_ctx.Emit(Diagnostic::Error(token->span)
252 << "integer number `" << literal_text << "` unrepresentable in width "
253 << width);
254 token->data = support::ValueToIntImm(0, width);
255 }
256 }
257
258 return token;
259 }
260
261 Token ParseNumber(bool is_pos) {
262 std::stringstream ss;
263 while (More() && IsNumeric(Peek())) {
264 ss << Next();
265 }
266
267 bool is_float = false;
268 if (More() && (Peek() == 'f' || Peek() == 'i')) {
269 is_float = Peek() == 'f';
270 // Capture trailing width suffix
271 ss << Next();
272 while (More() && IsNumeric(Peek())) {
273 ss << Next();
274 }
275 }
276 return ParseNumber(is_pos, is_float, ss.str());
277 }
278
279 bool MatchString(const std::string& string) {
280 int start = this->pos;
281
282 for (auto c : string) {
283 if (Peek() != c) {
284 this->pos = start;
285 return false;
286 } else {
287 Next();
288 }
289 }
290
291 return true;
292 }
293
294 Token TokenizeMetaRef() {
295 int line = this->line;
296 int column = this->col;
297
298 std::stringstream type_key;
299 while (More() && Peek() != ']') {
300 type_key << Next();
301 }
302 ICHECK_EQ(Peek(), ']');
303 Next();
304
305 ICHECK_EQ(Peek(), '[');
306 Next();
307 std::stringstream str_index;
308 while (More() && Peek() != ']') {
309 str_index << Next();
310 }
311 ICHECK_EQ(Peek(), ']');
312 Next();
313 // todo: add error handling around bad indices
314 auto index = ParseNumber(true, false, str_index.str()).ToNumber();
315 auto span = SpanFrom(line, column);
316 return Token(span, TokenType::kMetaReference, MetaRef(type_key.str(), index));
317 }
318
319 Token TokenizeAttr() {
320 int line = this->line;
321 int column = this->col;
322 Next();
323 if (Peek() == '[') {
324 Next();
325 std::stringstream raw_attribute;
326
327 while (More() && Peek() != ']') {
328 raw_attribute << Next();
329 }
330
331 ICHECK_EQ(Next(), ']');
332
333 auto attribute = raw_attribute.str();
334 // Clean up the white-space on both sides.
335 ltrim(attribute);
336 rtrim(attribute);
337
338 // Metadata can only appear at the bottom of a file and goes to EOF.
339 if (attribute == "metadata") {
340 std::stringstream metadata;
341 while (More()) {
342 metadata << Next();
343 }
344 ObjectRef metadata_map = tvm::LoadJSON(metadata.str());
345 auto span = SpanFrom(line, column);
346 return Token(span, TokenType::kMetadata, metadata_map);
347 }
348 if (attribute.rfind("version", 0) == 0) {
349 std::string version = attribute.substr(attribute.find("=") + 1);
350 ltrim(version);
351 rtrim(version);
352 auto span = SpanFrom(line, column);
353 return Token(span, TokenType::kVersion, tvm::String(version));
354 } else {
355 // TOOD(@jroesch): maybe make this a warning an continue parsing?
356 auto span = SpanFrom(line, column);
357 this->diag_ctx.EmitFatal(Diagnostic::Error(span) << "unsupported attribute " << attribute);
358 return Token();
359 }
360 } else {
361 auto span = SpanFrom(line, column);
362 this->diag_ctx
363 .EmitFatal(Diagnostic::Error(span)
364 << "`#` denotes the start of an attribute can only be followed by `[`"
365 << " found `" << Peek() << "`");
366 return Token();
367 }
368 }
369
370 inline Token TokenizeOnce() {
371 int line = this->line;
372 int col = this->col;
373 auto next = Peek();
374 VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next;
375 if (next == '\n') {
376 auto token = NewToken(TokenType::kNewline);
377 Next();
378 return token;
379 } else if (next == '\r') {
380 Next();
381 if (More() && Peek() == '\n') {
382 auto token = NewToken(TokenType::kNewline);
383 return token;
384 } else {
385 auto span = SpanFrom(line, col);
386 this->diag_ctx.EmitFatal(
387 Diagnostic::Error(span)
388 << "\\r carriage returns must be followed by a \\n in the TVM text format");
389 return Token();
390 }
391 } else if (next == '"') {
392 // TODO(@jroesch): Properly tokenize escape sequences in strings.
393 // see https://github.com/apache/tvm/issues/6153.
394 Next();
395 std::stringstream string_content;
396 while (More() && Peek() != '"') {
397 string_content << Next();
398 }
399 Next();
400 return NewToken(TokenType::kStringLiteral, tvm::String(string_content.str()));
401 } else if (IsWhitespace(next)) {
402 auto token = NewToken(TokenType::kWhitespace);
403 Next();
404 return token;
405 } else if (next == '-') {
406 int negs = 0;
407 while (More() && Peek() == '-') {
408 Next();
409 negs++;
410 }
411 bool is_neg = negs % 2 == 1;
412 if (More() && IsDigit(Peek())) {
413 return ParseNumber(!is_neg);
414 } else if (More() && MatchString("inff")) {
415 return ParseNumber(!is_neg, true, "inff");
416 } else {
417 // If there isn't a number right after either,
418 // this is really slow for lexing, should replace
419 // with multi-token return or something.
420 pos = pos - (negs - 1);
421 return NewToken(TokenType::kMinus);
422 }
423 } else if (IsDigit(next)) {
424 return ParseNumber(true);
425 } else if (MatchString("inff")) {
426 return ParseNumber(true, true, "inff");
427 } else if (next == '.') {
428 auto token = NewToken(TokenType::kPeriod);
429 Next();
430 return token;
431 } else if (next == ',') {
432 auto token = NewToken(TokenType::kComma);
433 Next();
434 return token;
435 } else if (next == '=') {
436 auto token = NewToken(TokenType::kEqual);
437 Next();
438 return token;
439 } else if (next == ';') {
440 auto token = NewToken(TokenType::kSemicolon);
441 Next();
442 return token;
443 } else if (next == ':') {
444 auto token = NewToken(TokenType::kColon);
445 Next();
446 return token;
447 } else if (next == '(') {
448 auto token = NewToken(TokenType::kOpenParen);
449 Next();
450 return token;
451 } else if (next == ')') {
452 auto token = NewToken(TokenType::kCloseParen);
453 Next();
454 return token;
455 } else if (next == '+') {
456 auto token = NewToken(TokenType::kPlus);
457 Next();
458 return token;
459 } else if (next == '*') {
460 auto token = NewToken(TokenType::kStar);
461 Next();
462 return token;
463 } else if (next == '<') {
464 auto token = NewToken(TokenType::kLAngle);
465 Next();
466 return token;
467 } else if (next == '>') {
468 auto token = NewToken(TokenType::kRAngle);
469 Next();
470 return token;
471 } else if (next == '{') {
472 auto token = NewToken(TokenType::kLCurly);
473 Next();
474 return token;
475 } else if (next == '}') {
476 auto token = NewToken(TokenType::kRCurly);
477 Next();
478 return token;
479 } else if (next == '[') {
480 auto token = NewToken(TokenType::kLSquare);
481 Next();
482 return token;
483 } else if (next == ']') {
484 auto token = NewToken(TokenType::kRSquare);
485 Next();
486 return token;
487 } else if (next == '!') {
488 auto token = NewToken(TokenType::kBang);
489 Next();
490 return token;
491 } else if (next == '@') {
492 auto token = NewToken(TokenType::kAt);
493 Next();
494 return token;
495 } else if (next == '?') {
496 auto token = NewToken(TokenType::kQuestion);
497 Next();
498 return token;
499 } else if (MatchString("meta[")) {
500 return TokenizeMetaRef();
501 } else if (next == '#') {
502 return TokenizeAttr();
503 } else if (next == '%') {
504 auto token = NewToken(TokenType::kPercent);
505 Next();
506
507 std::stringstream number;
508 while (More() && IsDigit(Peek())) {
509 number << Next();
510 }
511
512 auto number_str = number.str();
513 if (number_str.size()) {
514 auto num_tok = ParseNumber(true, false, number_str);
515 auto span = SpanFrom(token->span->line, token->span->column);
516 token = Token(span, TokenType::kGraph, num_tok->data);
517 }
518
519 return token;
520 } else if (next == '/') {
521 Next();
522 if (Peek() == '/') {
523 auto token = NewToken(TokenType::kLineComment);
524 // Consume the /
525 Next();
526 std::stringstream comment;
527 while (More() && Peek() != '\n') {
528 comment << Next();
529 }
530 token->data = tvm::String(comment.str());
531 return token;
532 } else if (Peek() == '*') {
533 // Eat the first /* pair before entering the state machine.
534 Next();
535 std::string comment;
536 MatchComment(&comment);
537 auto token = NewToken(TokenType::kComment, tvm::String(comment));
538 return token;
539 } else {
540 return NewToken(TokenType::kDivision);
541 }
542 } else if (IsIdentLetter(next)) {
543 std::stringstream ss;
544 // Due the below code we need to patch
545 // the line/col info to the start of
546 // token.
547 int line = this->line;
548 int col = this->col;
549
550 while (More() && IsIdent(Peek())) {
551 ss << Next();
552 }
553
554 std::string keyword = ss.str();
555 auto it = KEYWORD_TABLE.find(keyword);
556
557 TokenType token_type;
558 if (it != KEYWORD_TABLE.end()) {
559 token_type = it->second;
560
561 if (token_type == TokenType::kMatch) {
562 if (More() && Peek() == '?') {
563 Next();
564 token_type = TokenType::kPartialMatch;
565 }
566 }
567 } else {
568 token_type = TokenType::kIdentifier;
569 }
570
571 auto span = SpanFrom(line, col);
572 return Token(span, token_type, tvm::String(ss.str()));
573 } else {
574 std::stringstream ss;
575 while (More() && !IsWhitespace(Peek())) {
576 ss << Next();
577 }
578 auto token = NewToken(TokenType::kUnknown);
579 token->data = tvm::String(ss.str());
580 return token;
581 }
582 }
583
584 void Tokenize() {
585 VLOG(9) << "tvm::parser::Tokenize";
586 while (this->More()) {
587 auto token = TokenizeOnce();
588 ICHECK(token.defined());
589 this->tokens.push_back(token);
590 }
591 this->tokens.push_back(NewToken(TokenType::kEndOfFile));
592 }
593
594 explicit Tokenizer(const DiagnosticContext& ctx, const Source& source)
595 : diag_ctx(ctx),
596 source_name(source->source_name),
597 pos(0),
598 col(1),
599 line(1),
600 source(source->source),
601 tokens() {}
602};
603
604std::vector<Token> Condense(const std::vector<Token>& tokens, Token* table) {
605 std::vector<Token> out;
606 bool found_metadata = false;
607
608 for (size_t i = 0; i < tokens.size(); i++) {
609 auto current = tokens.at(i);
610 switch (current->token_type) {
611 case TokenType::kMetadata: {
612 if (!found_metadata) {
613 found_metadata = true;
614 *table = current;
615 } else {
616 LOG(FATAL) << "duplicate metadata section";
617 }
618 continue;
619 }
620 case TokenType::kPercent: {
621 auto next = tokens.at(i + 1);
622 if (next->token_type == TokenType::kIdentifier) {
623 // Match this token.
624 i += 1;
625 // TODO(@jroesch): merge spans
626 auto tok = Token(current->span, TokenType::kLocal, next->data);
627 ICHECK(tok.defined());
628 out.push_back(tok);
629 } else if (next->token_type == TokenType::kInteger) {
630 i += 1;
631 auto tok = Token(current->span, TokenType::kGraph, next->data);
632 ICHECK(tok.defined());
633 out.push_back(tok);
634 } else {
635 ICHECK(current.defined());
636 out.push_back(current);
637 }
638 continue;
639 }
640 case TokenType::kAt: {
641 auto next = tokens.at(i + 1);
642 if (next->token_type == TokenType::kIdentifier) {
643 // Match this token.
644 i += 1;
645 // TODO(@jroesch): merge spans
646 auto tok = Token(current->span, TokenType::kGlobal, next->data);
647 ICHECK(tok.defined());
648 out.push_back(tok);
649 } else {
650 ICHECK(current.defined());
651 out.push_back(current);
652 }
653 continue;
654 }
655 case TokenType::kIdentifier: {
656 std::string str = Downcast<tvm::String>(current->data);
657 Token tok;
658 // TODO(@jroesch): merge spans
659 if (str == "True") {
660 auto data = tvm::Integer(1);
661 tok = Token(current->span, TokenType::kBoolean, data);
662 } else if (str == "False") {
663 auto data = tvm::Integer(0);
664 tok = Token(current->span, TokenType::kBoolean, data);
665 } else if (str == "_") {
666 tok = Token(current->span, TokenType::kUnderscore);
667 } else {
668 tok = current;
669 }
670 out.push_back(tok);
671 continue;
672 }
673 default: {
674 out.push_back(current);
675 continue;
676 }
677 }
678 }
679
680 return out;
681}
682
683std::pair<std::vector<Token>, Token> Tokenize(const DiagnosticContext& ctx, const Source& source) {
684 auto tokenizer = Tokenizer(ctx, source);
685 tokenizer.Tokenize();
686 Token meta_table(Span(), TokenType::kUnknown, ObjectRef());
687 auto tokens = Condense(tokenizer.tokens, &meta_table);
688 for (auto token : tokens) {
689 ICHECK(token.defined());
690 }
691 return {tokens, meta_table};
692}
693
694} // namespace parser
695} // namespace tvm
696
697#endif // TVM_PARSER_TOKENIZER_H_
698