1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file parser.cc
22 * \brief A parser for TVM IR.
23 */
24#include <tvm/ir/module.h>
25#include <tvm/node/reflection.h>
26#include <tvm/parser/parser.h>
27#include <tvm/relay/adt.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/function.h>
30#include <tvm/relay/transform.h>
31#include <tvm/runtime/logging.h>
32#include <tvm/runtime/object.h>
33#include <tvm/runtime/registry.h>
34#include <tvm/target/virtual_device.h>
35
36#include <fstream>
37
38#include "../support/scalars.h"
39#include "./meta_ref.h"
40#include "./op_table.h"
41#include "./span_check.h"
42#include "./tokenizer.h"
43#include "tvm/runtime/builtin_fp16.h"
44
45namespace tvm {
46namespace parser {
47
48using namespace relay;
49using Expr = relay::Expr;
50
51/*! \brief The meta table maps from type key to a sequence of objects. */
52using MetaTable = Map<String, Array<ObjectRef>>;
53
54using tvm::transform::CreateModulePass;
55using tvm::transform::PassContext;
56
57/*! \brief A helper for passing around spans with data structures with
58 * no span field.
59 */
60template <typename T>
61struct Spanned {
62 T data;
63 Span span;
64
65 Spanned() = default;
66 Spanned(const Spanned<T>& other) = default;
67 Spanned(T data, Span span) : data(data), span(span) {}
68};
69
70/*! \brief A wrapper structure for capturing the result of parsing
71 * a global definition *before* we add it to the IRModule.
72 *
73 * This enables the parser to parse everything in one pass before
74 * constructing the IRModule.
75 */
76struct GlobalFunc {
77 GlobalVar global;
78 Function function;
79 GlobalFunc() : global(), function() {}
80 GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {}
81 GlobalFunc(const GlobalFunc& gfunc) {
82 this->global = gfunc.global;
83 this->function = gfunc.function;
84 }
85};
86
87/*! \brief A wrapper structure for capturing all top-level definitions
88 * when parsing a module.
89 */
90struct Definitions {
91 /*! \brief The set of global functions. */
92 std::vector<GlobalFunc> funcs;
93 /*! \brief The set of type definitions. */
94 std::vector<TypeData> types;
95 // TODO(@jroesch): contain meta-table below
96};
97
98/*! \brief A structure representing the semantic versioning information
99 * for a Relay program.
100 */
101class SemVer {
102 public:
103 int major_version;
104 int minor_version;
105 int patch_version;
106
107 SemVer() : major_version(0), minor_version(0), patch_version(0) {}
108 SemVer(int major_version, int minor_version, int patch_version)
109 : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {}
110 SemVer(const SemVer& other)
111 : major_version(other.major_version),
112 minor_version(other.minor_version),
113 patch_version(other.patch_version) {}
114};
115
116/*! \brief A simple wrapper around a mapping from raw string names
117 * to a TVM variable, type variable or other binder type.
118 */
119template <typename T>
120struct Scope {
121 /*! \brief The internal map. */
122 std::unordered_map<std::string, T> name_map;
123};
124
125/*! \brief A stack of scopes.
126 *
127 * In order to properly handle scoping we must maintain a stack of scopes.
128 *
129 * A stack allows users to write programs which contain repeated variable
130 * names and to properly handle both nested scopes and removal of variables
131 * when they go out of scope.
132 *
133 * This is the classic approach to lexical scoping.
134 */
135template <typename T>
136class ScopeStack {
137 private:
138 std::vector<Scope<T>> scope_stack;
139 std::unordered_map<std::string, T> free_vars;
140
141 public:
142 /*! \brief Adds a variable binding to the current scope. */
143 void Add(const std::string& name, const T& value) {
144 if (!this->scope_stack.size()) {
145 LOG(FATAL) << "internal issue";
146 }
147 this->scope_stack.back().name_map.insert({name, value});
148 }
149
150 void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); }
151
152 /*! \brief Looks up a variable name in the scope stack returning the matching variable
153 * in most recent scope. */
154 T Lookup(const std::string& name) {
155 for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) {
156 auto it = scope->name_map.find(name);
157 if (it != scope->name_map.end()) {
158 return it->second;
159 }
160 }
161
162 // Check if we bound a free variable declaration.
163 auto it = free_vars.find(name);
164 if (it != free_vars.end()) {
165 return it->second;
166 }
167
168 return T();
169 }
170
171 /*! \brief Adds a fresh scope. */
172 void PushStack() { this->scope_stack.push_back(Scope<T>()); }
173
174 /*! \brief Removes the most recent scope. */
175 void PopStack() { this->scope_stack.pop_back(); }
176};
177
178struct DuplicateKeyError : public Error {
179 explicit DuplicateKeyError(const std::string& msg) : Error(msg) {}
180};
181
182/*! \brief A table of interning strings as global function and type names. */
183template <typename T>
184struct InternTable {
185 /*! \brief The internal table mapping strings to a unique allocation. */
186 std::unordered_map<std::string, T> table;
187 DiagnosticContext* ctx;
188
189 /*! \brief Add the unique allocation. */
190 void Add(const std::string& name, const T& t) {
191 auto it = table.find(name);
192 if (it != table.end()) {
193 throw DuplicateKeyError("duplicate key name in intern table");
194 } else {
195 table.insert({name, t});
196 }
197 }
198
199 /*! \brief Return the unique allocation. */
200 Optional<T> Get(const std::string& name) const {
201 auto it = table.find(name);
202 if (it != table.end()) {
203 return Optional<T>(it->second);
204 } else {
205 return Optional<T>();
206 }
207 }
208};
209
210GlobalVar AddOrGet(InternTable<GlobalVar>* table, const std::string& name) {
211 auto var = table->Get(name);
212 if (var) {
213 return var.value();
214 } else {
215 auto gvar = GlobalVar(name);
216 table->Add(name, gvar);
217 return gvar;
218 }
219}
220
221GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& name,
222 TypeKind kind = TypeKind::kType) {
223 auto var = table->Get(name);
224 if (var) {
225 auto tvar = var.value();
226 TypeKind& tvar_kind = const_cast<TypeKind&>(tvar->kind);
227 tvar_kind = kind;
228 return tvar;
229 } else {
230 auto gvar = GlobalTypeVar(name, kind);
231 table->Add(name, gvar);
232 return gvar;
233 }
234}
235
236/*! \brief The parser class is the main interface to the parser.
237 * the parser is not currently exposed beyond this .cc file.
238 *
239 * The parser is initialized with a diagnostic context, an
240 * operator table, and a token stream.
241 *
242 * The rest of the internal state is used to map the human readable
243 * form to in-memory IR representation.
244 *
245 * The main entry point to the parser are a set of parsing methods
246 * such as `ParseModule` and `ParseExpr`.
247 *
248 * As with traditional recursive descent parsers the parsing methods
249 * are factored recursively just as one would do with a formal language
250 * grammar.
251 *
252 * You can view a recursive descent parser as a human friendly way to specify
253 * a state machine, and thus this factoring is necessary as the 'state' of this
254 * machine is the combination of the current parsing method and the next token.
255 *
256 * Parsing proceeds by matching a token and then dispatching to the appropriate
257 * method to parse the next tokens in the stream.
258 *
259 * For example if we are parsing a type and encounter a "Tensor" token we switch
260 * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`.
261 *
262 * Certain matches like this are unambiguous and proceed in a straight line fashion
263 * once the initial token is found. Other parsing is more complex and requires some
264 * tricks to correctly parse.
265 *
266 * For example when we find a '(' in an expression context, it may be part of
267 * a tuple, the arguments to a call, or a parenthesized expression. The below code
268 * disambiguate these cases by factoring expression parsing into a series of methods
269 * which encode the parsing context and thus how to interpret the parenthesis.
270 *
271 * For more information one should be able to read the code in order starting with
272 * `ParseModule` or `ParseExpr`.
273 */
274class Parser {
275 public:
276 /*! \brief The version that the parser is parsing. */
277 SemVer version;
278
279 /*! \brief The IRModule we are building. */
280 IRModule module;
281
282 /*! \brief The diagnostic context used for error reporting. */
283 DiagnosticContext diag_ctx;
284
285 const Source& source;
286
287 /*! \brief The current position in the token stream. */
288 int pos;
289
290 /*! \brief The token stream for the parser. */
291 std::vector<Token> tokens;
292
293 /*! \brief The configured operator table. */
294 OperatorTable op_table;
295
296 /*! \brief Configure the whitespace mode, right now we ignore all whitespace. */
297 bool ignore_whitespace;
298
299 /*! \brief A global mapping for GlobalVar. */
300 InternTable<GlobalVar> global_names;
301
302 /*! \brief A global mapping for type definitions. */
303 InternTable<GlobalTypeVar> type_names;
304
305 /*! \brief A global mapping for constructor names. */
306 InternTable<Constructor> ctors;
307
308 /*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */
309 std::unordered_map<int, Expr> graph_ctx;
310
311 /*! \brief The set of type scopes used for generics. */
312 ScopeStack<TypeVar> type_scopes;
313
314 /*! \brief The set of expression scopes used for lexical scope. */
315 ScopeStack<Var> expr_scopes;
316
317 /*! \brief The metadata section. */
318 MetaTable meta_table;
319
320 Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector<Token> tokens,
321 OperatorTable op_table, MetaTable table)
322 : module(module),
323 diag_ctx(ctx),
324 source(source),
325 pos(0),
326 tokens(tokens),
327 op_table(op_table),
328 ignore_whitespace(true),
329 meta_table(table) {
330 InitializeGlobals();
331 InitializeTypeDefs();
332 }
333
334 /*! If we are parsing into a module with previously loaded data types we need to
335 * map constructor names and variable names in the global tables.
336 */
337 void InitializeTypeDefs() {
338 for (auto pair : this->module->type_definitions) {
339 type_names.Add(pair.first->name_hint, pair.first);
340 for (auto ctor : pair.second->constructors) {
341 ctors.Add(ctor->name_hint, ctor);
342 }
343 }
344 }
345
346 void InitializeGlobals() {
347 for (auto pair : this->module->functions) {
348 global_names.Add(pair.first->name_hint, pair.first);
349 }
350 }
351
352 /*! \brief Examine the next token in the stream, the current parser is configured to be
353 * whitespace insensitive so we will skip all whitespace or comment tokens. */
354 Token Peek() {
355 // For now we ignore all whitespace tokens and comments.
356 // We can tweak this behavior later to enable white space sensitivity in the parser.
357 while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace &&
358 (tokens.at(pos)->token_type == TokenType::kWhitespace ||
359 tokens.at(pos)->token_type == TokenType::kNewline ||
360 tokens.at(pos)->token_type == TokenType::kLineComment ||
361 tokens.at(pos)->token_type == TokenType::kComment)) {
362 pos++;
363 }
364
365 if (pos < static_cast<int64_t>(tokens.size())) {
366 return Token(this->tokens.at(pos));
367 } else {
368 return Token::Null();
369 }
370 }
371
372 /*! \brief Lookahead by N tokens.
373 * \param n The number of tokens to lookahead.
374 * \return The Nth token.
375 */
376 Token Lookahead(int n) {
377 ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1";
378
379 // We intend to skip n - 1 tokens, then return the nth.
380 auto old_pos = pos;
381 for (int i = 0; i < n - 1; i++) {
382 Peek();
383 pos++;
384 }
385
386 auto tok = Peek();
387 pos = old_pos;
388 return tok;
389 }
390
391 /*! \brief Consume a token, this method is the lowest level way to consume a token
392 * and will not ignore white space or look ahead in anyway.
393 *
394 * /param token_type The token type to match.
395 */
396 void Consume(const TokenType& token_type) {
397 if (tokens[pos]->token_type != token_type) {
398 this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span)
399 << "expected a " << Pretty(token_type) << " found "
400 << Pretty(Peek()->token_type));
401 }
402 pos++;
403 }
404
405 /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such
406 * as whitespace or comments returning the first meaningful token.
407 *
408 * We then try and consume the requested token, this will trigger an error if the
409 * current token does not match the token_type.
410 */
411 Token Match(const TokenType& token_type) {
412 auto tok = Peek();
413 Consume(token_type);
414 return tok;
415 }
416
417 /*! Conditionally consume a token when it matches, this will never trigger an error
418 * as we guard against consuming the token before we do.
419 *
420 * Useful for matching optional tokens, effectively looksahead by one.
421 */
422 bool WhenMatch(const TokenType& token_type) {
423 VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek();
424 if (Peek()->token_type == token_type) {
425 Consume(token_type);
426 return true;
427 } else {
428 return false;
429 }
430 }
431
432 /* \brief Add a graph binding to the parsing context
433 *
434 * For example if we parse %0 = add(...), map 0 -> add(...), etc.
435 */
436 void AddGraphBinding(const Token& token, const Expr& expr) {
437 auto graph_no = token.ToNumber();
438 this->graph_ctx.insert({graph_no, expr});
439 }
440
441 /* \brief Lookup a previously bound graph variable.
442 *
443 * Note: we take tokens in all lookup methods so that we
444 * that we can do error reporting based on token location.
445 */
446 Expr LookupGraphBinding(const Token& token) {
447 auto graph_no = token.ToNumber();
448 auto it = this->graph_ctx.find(graph_no);
449 if (it != this->graph_ctx.end()) {
450 return it->second;
451 } else {
452 LOG(FATAL) << "Local variable %" << graph_no << " has not yet been defined";
453 throw;
454 }
455 }
456
457 /*! \brief Bind a local variable in the expression scope.
458 *
459 * "x" -> Var("x"), these are needed to map from the raw string names
460 * to unique variable nodes.
461 * If a virtual device is specified, sets the virtual device of the variable.
462 */
463 Var BindVar(const std::string& name, const relay::Type& type_annotation,
464 Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) {
465 auto var = Var(name, type_annotation);
466 var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained());
467 VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var);
468 this->expr_scopes.Add(name, var);
469 return var;
470 }
471
472 /*! \brief Bind a local variable in the expression scope.
473 *
474 * "x" -> Var("x"), these are needed to map from the raw string names
475 * to unique variable nodes.
476 */
477 Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) {
478 auto var = Var(name, type_annotation);
479 this->expr_scopes.AddFreeVar(name, var);
480 return var;
481 }
482
483 /*! \brief Bind a type variable in the type scope.
484 *
485 * "A" -> TypeVar("A", ...), these are needed to map from raw string names
486 * to unique type variable nodes.
487 */
488 TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) {
489 auto type_var = TypeVar(name, type_kind);
490 this->type_scopes.Add(name, type_var);
491 return type_var;
492 }
493
494 /*! \brief Lookup a variable in the expression scope.
495 *
496 * Note: all lookup methods take tokens intentionally for error reporting information.
497 */
498 Var LookupLocal(const Token& local) {
499 auto var = this->expr_scopes.Lookup(local.ToString());
500 if (!var.defined()) {
501 diag_ctx.Emit(Diagnostic::Error(local->span)
502 << "this local variable has not been previously declared");
503 }
504 return var;
505 }
506
507 /*! \brief Lookup a variable in the type scope.
508 *
509 * Note: all lookup methods take tokens intentionally for error reporting information.
510 */
511 TypeVar LookupTypeVar(const Token& ident) {
512 auto var = this->type_scopes.Lookup(ident.ToString());
513 return var;
514 }
515
516 /*! \brief Add an expression scope to the scope stack. */
517 void PushScope() { this->expr_scopes.PushStack(); }
518
519 /*! \brief Remove N expression scopes from the scope stack. */
520 void PopScopes(int n) {
521 for (int i = 0; i < n; i++) {
522 this->expr_scopes.PopStack();
523 }
524 }
525
526 /*! \brief Add an type scope to the scope stack. */
527 void PushTypeScope() { this->type_scopes.PushStack(); }
528
529 /*! \brief Remove N type scopes from the scope stack. */
530 void PopTypeScopes(int n) {
531 for (int i = 0; i < n; i++) {
532 this->type_scopes.PopStack();
533 }
534 }
535
536 /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
537 NDArray NumberToNDArray(const Token& token) {
538 if (token->token_type == TokenType::kInteger) {
539 return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data));
540 } else if (token->token_type == TokenType::kFloat) {
541 return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data));
542 } else {
543 LOG(FATAL) << "internal error: should only call this function on numeric tokens";
544 }
545 }
546
547 [[noreturn]] void ParseError(const Token& token, const std::string& msg) {
548 throw std::runtime_error(msg);
549 }
550
551 /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */
552 template <typename R>
553 R Bracket(TokenType open, TokenType close, std::function<R()> parser) {
554 Match(open);
555 R result = parser();
556 Match(close);
557 return result;
558 }
559
560 /*! \brief Parse `(` parser() `)`. */
561 template <typename R>
562 R Parens(std::function<R()> parser) {
563 return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser);
564 }
565
566 /*! \brief Parse `{` parser() `}`. */
567 template <typename R>
568 R Block(std::function<R()> parser) {
569 return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser);
570 }
571
572 template <typename R>
573 R WithSpan(std::function<R()> parser) {
574 auto start_span = Peek()->span;
575 VLOG(9) << "WithSpan: start_span = " << start_span;
576 R ast = parser();
577 if (ast.defined()) {
578 // The token at the head of the stream is now 1 past where we parsed. So we find its start
579 // position as its start and end, so that when we merge we only grow the spanned region
580 // to the start of the current stream.
581 auto span_pos = pos - 1;
582 while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace ||
583 tokens.at(span_pos)->token_type == TokenType::kNewline ||
584 tokens.at(span_pos)->token_type == TokenType::kLineComment ||
585 tokens.at(span_pos)->token_type == TokenType::kComment)) {
586 span_pos--;
587 }
588 auto end_token = tokens.at(span_pos);
589 VLOG(9) << "WithSpan: end_span = " << end_token->span;
590 ast->span = start_span.Merge(end_token->span);
591 }
592 return ast;
593 }
594
595 struct MetaRef {
596 std::string type_key;
597 uint64_t node_index;
598 Span span;
599 MetaRef(std::string type_key, uint64_t node_index, Span span)
600 : type_key(type_key), node_index(node_index), span(span) {}
601 };
602
603 MetaRef MetaRefFromToken(const Token& tok) {
604 Call ref = Downcast<Call>(tok->data);
605 auto attrs = ref->attrs.as<MetaRefAttrs>();
606 auto type_key = attrs->node_type_key;
607 auto index = attrs->node_index;
608 return MetaRef(type_key, index, ref->span);
609 }
610
611 /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`.
612 * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]`
613 * the second, and so on.
614 */
615 ObjectRef ParseMetaRef() {
616 auto meta_ref_tok = Match(TokenType::kMetaReference);
617 auto meta_ref = MetaRefFromToken(meta_ref_tok);
618 auto it = this->meta_table.find(meta_ref.type_key);
619 if (it != this->meta_table.end()) {
620 auto nodes = (*it).second;
621 if (meta_ref.node_index < nodes.size()) {
622 return nodes[meta_ref.node_index];
623 } else {
624 this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
625 << "the node index `" << meta_ref.node_index
626 << "` is out of bounds for `" << meta_ref.type_key << "`");
627 return ObjectRef();
628 }
629 } else {
630 this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
631 << "no entry in the meta table for `" << meta_ref.type_key << "`");
632 return ObjectRef();
633 }
634 }
635 /*! \brief Parses a sequence beginning with a start token, separated by a seperator token, and
636 * ending with a stop token.
637 *
638 * The simple form being <start> (<parse()> <seperator>)* <stop>.
639 *
640 * This also provides a fourth argument which is allowed to run when the sequence which matches
641 * the inner sequence can not proceed.
642 *
643 * This is useful for parsing things like attributes which don't match the standard expression
644 * parsers but are contained within the stop token.
645 */
646 template <typename T>
647 Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
648 std::function<bool()> before_stop = nullptr) {
649 VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
650 << " stop=" << ToString(stop);
651 Match(start);
652
653 // This is for the empty arguments list case, if we have <start> <leftovers> <stop> token stream
654 // we must parse leftovers, then match a stop token.
655 if (before_stop) {
656 auto did_parse = before_stop();
657 if (did_parse) {
658 Match(stop);
659 return {};
660 }
661 }
662
663 // This is the case in which we find an empty arguments lists and no leftovers.
664 if (WhenMatch(stop)) {
665 return Array<T>();
666 } else {
667 VLOG(9) << "Parser::ParseSequence: parse first";
668 auto data = parse();
669 Array<T> elements = {data};
670
671 if (WhenMatch(stop)) {
672 return elements;
673 // parse '( expr ',' * ')'
674 } else if (WhenMatch(sep)) {
675 while (true) {
676 VLOG(9) << "Parser::ParseSequence: parse element";
677 if (WhenMatch(stop)) {
678 break;
679 } else {
680 // If before stop is
681 if (before_stop) {
682 auto did_parse = before_stop();
683 if (did_parse) {
684 Match(stop);
685 return elements;
686 }
687 }
688 auto data = parse();
689 WhenMatch(sep);
690 elements.push_back(data);
691 }
692 }
693 return elements;
694 } else {
695 auto next = Peek();
696 this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
697 << "expected a " << Pretty(stop) << " found "
698 << Pretty(next->token_type));
699 return Array<T>(nullptr);
700 }
701 }
702 }
703
704 /*! \brief Parse a full IRModule. */
705 IRModule ParseModule() {
706 // Parse the semver header at the top of the module.
707 this->version = ParseSemVer();
708 // Parse the definitions.
709 auto defs = ParseDefinitions();
710 // Parse the metadata section at the end.
711 auto metadata = ParseMetadata();
712
713 Match(TokenType::kEndOfFile);
714
715 for (auto type_def : defs.types) {
716 module->AddTypeDef(type_def->header, type_def);
717 }
718
719 for (auto func : defs.funcs) {
720 module->Add(func.global, func.function, true);
721 }
722
723 return module;
724 }
725
726 /*! \brief Parse the semantic versioning header. */
727 SemVer ParseSemVer(bool required = true) {
728 if (Peek()->token_type == TokenType::kVersion) {
729 auto version = Match(TokenType::kVersion);
730 // TODO(@jroesch): we currently only support 0.0.5.
731 if (version.ToString() != "\"0.0.5\"") {
732 this->diag_ctx.Emit(Diagnostic::Error(version->span)
733 << "invalid semantic version `" << version.ToString() << "`");
734 }
735 } else if (required) {
736 this->diag_ctx.Emit(Diagnostic::Error(Peek()->span)
737 << "expected text format semantic version, found a "
738 << PrettyPrint(Peek()));
739
740 this->diag_ctx.Emit(Diagnostic::Help(Peek()->span)
741 << "you can annotate it as #[version = \"0.0.5\"]");
742 }
743 return SemVer(0, 0, 5);
744 }
745
746 /*! \brief Parse zero or more Relay definitions. */
747 Definitions ParseDefinitions() {
748 Definitions defs;
749
750 while (true) {
751 auto next = Peek();
752 switch (next->token_type) {
753 case TokenType::kDefn: {
754 Consume(TokenType::kDefn);
755 auto global_tok = Match(TokenType::kGlobal);
756 auto global_name = global_tok.ToString();
757 auto global = AddOrGet(&global_names, global_name);
758 auto func = WithSpan<relay::Function>([&]() { return ParseFunctionDef(); });
759 ICHECK(func->span.defined()) << "spans must be set in parser";
760 defs.funcs.push_back(GlobalFunc(global, func));
761 continue;
762 }
763 case TokenType::kTypeDef: {
764 defs.types.push_back(ParseTypeDef());
765 continue;
766 }
767 case TokenType::kExtern: {
768 Consume(TokenType::kExtern);
769 auto type_def = ParseTypeDef();
770 if (type_def->constructors.size()) {
771 diag_ctx.Emit(Diagnostic::Error(next->span)
772 << "an external type may not have any constructors");
773 }
774 defs.types.push_back(type_def);
775 }
776 default:
777 return defs;
778 }
779 }
780 }
781
782 /*! \brief Parse zero or more Relay type definitions. */
783 TypeData ParseTypeDef() {
784 // Match the `type` keyword.
785 Match(TokenType::kTypeDef);
786 // Parse the type's identifier.
787 auto type_tok = Match(TokenType::kIdentifier);
788 auto type_id = type_tok.ToString();
789 auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle);
790
791 Array<TypeVar> generics;
792
793 bool should_pop = false;
794 if (Peek()->token_type == TokenType::kLSquare) {
795 // If we have generics we need to add a type scope.
796 PushTypeScope();
797 should_pop = true;
798 generics = ParseSequence<TypeVar>(
799 TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
800 auto type_var_name = Match(TokenType::kIdentifier).ToString();
801 return BindTypeVar(type_var_name, TypeKind::kType);
802 });
803 }
804
805 Array<tvm::Constructor> ctors;
806 if (Peek()->token_type == TokenType::kLCurly) {
807 // Parse the list of constructors.
808 ctors = ParseSequence<tvm::Constructor>(
809 TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() {
810 // First match the name of the constructor.
811 auto ctor_tok = Match(TokenType::kIdentifier);
812 auto ctor_name = ctor_tok.ToString();
813
814 Constructor ctor;
815 // Match the optional field list.
816 if (Peek()->token_type != TokenType::kOpenParen) {
817 ctor = tvm::Constructor(ctor_name, {}, type_global);
818 } else {
819 auto arg_types =
820 ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
821 TokenType::kCloseParen, [&]() { return ParseType(); });
822 ctor = tvm::Constructor(ctor_name, arg_types, type_global);
823 }
824
825 ICHECK(ctor.defined());
826
827 try {
828 this->ctors.Add(ctor_name, ctor);
829 } catch (const DuplicateKeyError& e) {
830 this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span)
831 << "a constructor with the name "
832 << "`" << ctor_name << "` "
833 << "was previously defined");
834 }
835
836 return ctor;
837 });
838 }
839
840 // Now pop the type scope.
841 if (should_pop) {
842 PopTypeScopes(1);
843 }
844
845 return TypeData(type_global, generics, ctors);
846 }
847
848 std::string HackTokensAsString(int n) {
849 std::stringstream key;
850 n = std::min(static_cast<int>(tokens.size() - pos), n);
851 for (int i = 0; i < n; i++) {
852 key << ToString(tokens.at(pos + i)->token_type);
853 }
854 return key.str();
855 }
856
857 std::vector<Rule> ParseOp() {
858 std::vector<Rule> matched;
859 Peek();
860 for (int i = 4; i > 0; i--) {
861 auto key = HackTokensAsString(i);
862 auto it = this->op_table.this_is_a_hack.find(key);
863 if (it != this->op_table.this_is_a_hack.end()) {
864 pos = pos + i;
865 matched.push_back(it->second);
866 }
867 }
868
869 return matched;
870 }
871
872 /*! \brief Parse a single Relay expression. */
873 Expr ParseExpr() {
874 VLOG(9) << "Parser::ParseExpr";
875 return WithSpan<Expr>([this] {
876 std::vector<Expr> exprs;
877
878 while (true) {
879 VLOG(9) << "Parser::ParseExpr: parsing a single expression";
880 auto next = Peek();
881 switch (next->token_type) {
882 // For graph or let, match first rhs, then invoke ParseBindingExpr
883 // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
884 case TokenType::kLCurly: {
885 // NB: Might need to optimize to remove deep recursion.
886 // Stack should only grow proportionally to the number of
887 // nested scopes.
888 // Parses `{` expression `}`.
889 auto block = WithSpan<Expr>([&]() {
890 return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
891 PushScope();
892 auto expr = ParseExpr();
893 PopScopes(1);
894 return expr;
895 });
896 });
897 exprs.push_back(block);
898 break;
899 }
900 case TokenType::kFreeVar: {
901 Consume(TokenType::kFreeVar);
902 auto var_token = Match(TokenType::kLocal);
903
904 Type type;
905 if (WhenMatch(TokenType::kColon)) {
906 type = ParseType();
907 } else {
908 type = IncompleteType();
909 }
910
911 BindFreeVar(var_token.ToString(), type);
912 break;
913 }
914 // Parses `let ...`;
915 case TokenType::kLet:
916 exprs.push_back(ParseBindingExpr());
917 break;
918 case TokenType::kMatch:
919 case TokenType::kPartialMatch: {
920 bool is_total = next->token_type == TokenType::kMatch;
921 Consume(next->token_type);
922 exprs.push_back(ParseMatch(is_total));
923 break;
924 }
925
926 // %x ...
927 case TokenType::kGraph:
928 if (Lookahead(2)->token_type == TokenType::kEqual) {
929 exprs.push_back(ParseBindingExpr());
930 break;
931 }
932 // intentional fall through here.
933 default: {
934 exprs.push_back(ParseExprBinOp());
935 break;
936 }
937 }
938
939 if (!WhenMatch(TokenType::kSemicolon)) {
940 break;
941 }
942 }
943
944 ICHECK_GE(exprs.size(), 1);
945
946 if (exprs.size() == 1) {
947 // ICHECK(exprs[0].defined() && exprs[0]->span.defined())
948 // << "parser must set expression spans.\n"
949 // << exprs[0];
950 return exprs[0];
951 } else {
952 auto body = exprs.back();
953 exprs.pop_back();
954 while (exprs.size()) {
955 auto value = exprs.back();
956 ICHECK(value->span.defined()) << "parser must set expression spans.";
957 exprs.pop_back();
958 body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span));
959 }
960 ICHECK(body->span.defined()) << "parser must set expression spans.";
961 return body;
962 }
963 });
964 }
965
966 /*! \brief Parse a "binding expression"; an expression where
967 * a graph or let variable is bound.
968 *
969 * In order to avoid stack overflow this is implemented in a special
970 * iterative way to keep stack depth constant in a long chain of bindings.
971 */
972 Expr ParseBindingExpr() {
973 // We use a loop here so that the stack depth
974 // does not grow linearly with a sequence of
975 // graph or let bindings.
976 //
977 // Assuming we start at call depth k, we will
978 // enter k + c call frames to parse the RHS
979 // of the bindings where `c` is the depth
980 // of recursion needed by RHS.
981 //
982 // If RHS is a call expresssion the c=1.
983 //
984 // Once we have parsed the RHS we will be
985 // back at depth K, and will return to
986 // this loop header to parse another
987 // graph or let binding.
988 //
989 // This ensures for n sequential bindings
990 // the call depth will be the same before
991 // and after parsing the n bindings.
992 VLOG(9) << "Parser::ParseBindingExpr";
993 std::vector<std::tuple<Var, Expr, Span>> bindings;
994 int scopes = 0;
995
996 while (true) {
997 auto next = Peek();
998 if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) {
999 Match(TokenType::kGraph);
1000 Match(TokenType::kEqual);
1001 auto val = this->ParseExprBinOp();
1002 Match(TokenType::kSemicolon);
1003 AddGraphBinding(next, val);
1004 } else if (next->token_type == TokenType::kLet) {
1005 auto span = next->span;
1006 // Parse the 'let'.
1007 Consume(TokenType::kLet);
1008
1009 // Parse the local '%<id>'.
1010 auto local_tok = Match(TokenType::kLocal);
1011 auto string = local_tok.ToString();
1012
1013 // Parse the optional type annotation (':' <type>).
1014 Type type;
1015 if (WhenMatch(TokenType::kColon)) {
1016 type = ParseType();
1017 }
1018
1019 auto var = BindVar(string, type);
1020
1021 // Parse the '=';
1022 Match(TokenType::kEqual);
1023
1024 // Parse the body, and the ';'.
1025 auto val = this->ParseExprBinOp();
1026 Consume(TokenType::kSemicolon);
1027
1028 // Add the bindings to the local data structure.
1029 std::tuple<relay::Var, relay::Expr, Span> tuple(var, val, span);
1030 bindings.push_back(tuple);
1031 scopes++;
1032 PushScope();
1033 } else {
1034 // This is the only case we will increase the stack
1035 // depth.
1036 //
1037 // If we parse a program which is a sequence of N bindings
1038 // followed by a single body expression we will end up with
1039 // a call depth of 3, the first call to ParseExpr, then
1040 // ParseBindingExpr, then finally ParseExpr once more.
1041
1042 auto body = this->ParseExpr();
1043
1044 // Remove the same number of scopes we added.
1045 PopScopes(scopes);
1046
1047 if (bindings.size() == 0) {
1048 return body;
1049 } else {
1050 // We can now build the let binding up backwards.
1051 for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) {
1052 auto span = body->span.Merge(std::get<2>(*binding));
1053 body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span);
1054 }
1055 return body;
1056 }
1057 }
1058 }
1059 }
1060
1061 /*! Parse a function definition without a leading keyword or identifier.
1062 *
1063 * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
1064 */
1065 Function ParseFunctionDef() {
1066 VLOG(9) << "Parser::ParseFunctionDef";
1067 return WithSpan<Function>([&]() {
1068 PushScope();
1069 PushTypeScope();
1070
1071 Array<TypeVar> generics;
1072 if (Peek()->token_type == TokenType::kLSquare) {
1073 generics = ParseSequence<TypeVar>(
1074 TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
1075 auto type_var_name = Match(TokenType::kIdentifier).ToString();
1076 return BindTypeVar(type_var_name, TypeKind::kType);
1077 });
1078 }
1079
1080 Map<String, ObjectRef> raw_attrs;
1081
1082 auto params = ParseSequence<Var>(
1083 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1084 [&]() {
1085 auto token = Match(TokenType::kLocal);
1086 auto string = token.ToString();
1087
1088 // The fake attributes where the virtual device is specified.
1089 VirtualDevice virtual_device;
1090 if (WhenMatch(TokenType::kLCurly)) {
1091 Map<String, ObjectRef> fake_attrs = ParseAttrs();
1092 VLOG(9) << "Fake attributes for function parameter: " << fake_attrs;
1093 Match(TokenType::kRCurly);
1094 if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) {
1095 ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>())
1096 << "Expected the " << kVirtualDevice
1097 << " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey();
1098 virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]);
1099 }
1100 }
1101
1102 Type type;
1103 if (WhenMatch(TokenType::kColon)) {
1104 type = ParseType();
1105 }
1106 return BindVar(string, type, virtual_device);
1107 },
1108 [&] {
1109 auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
1110 auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
1111
1112 if (is_ident && next_is_equal) {
1113 raw_attrs = ParseAttrs();
1114 return true;
1115 }
1116
1117 return false;
1118 });
1119
1120 Type ret_type;
1121 if (WhenMatch(TokenType::kMinus)) {
1122 Match(TokenType::kRAngle);
1123 ret_type = ParseType();
1124 }
1125
1126 auto body = Block<Expr>([&]() { return ParseExpr(); });
1127
1128 PopTypeScopes(1);
1129 PopScopes(1);
1130
1131 // TODO(@jroesch): attributes should never be null, they should always be empty.
1132 if (raw_attrs.size()) {
1133 // Promote kVirtualDevice to first-class
1134 if (raw_attrs.count(kVirtualDevice)) {
1135 ObjectRef vid = raw_attrs.at(kVirtualDevice);
1136 ICHECK(vid.as<VirtualDeviceNode>())
1137 << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
1138 << vid->GetTypeKey();
1139
1140 DictAttrs attrs;
1141 // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the
1142 // attributes
1143 if (raw_attrs.size() > 1) {
1144 raw_attrs.erase(kVirtualDevice);
1145 attrs = DictAttrs(raw_attrs);
1146 }
1147 Function func = relay::Function(params, body, ret_type, generics, attrs);
1148 func->virtual_device_ = vid;
1149 return func;
1150 } else {
1151 return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1152 }
1153 } else {
1154 return relay::Function(params, body, ret_type, generics, tvm::DictAttrs());
1155 }
1156 });
1157 }
1158
1159 /*! \brief Parse an if-expression. */
1160 Expr ParseIf() {
1161 return WithSpan<Expr>([&]() {
1162 VLOG(9) << "Parser::ParseIf";
1163 Consume(TokenType::kIf);
1164
1165 auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); });
1166
1167 auto true_branch = Block<Expr>([&] {
1168 this->PushScope();
1169 auto expr = ParseExpr();
1170 this->PopScopes(1);
1171 return expr;
1172 });
1173
1174 Match(TokenType::kElse);
1175
1176 auto false_branch = Block<Expr>([&] {
1177 this->PushScope();
1178 auto expr = ParseExpr();
1179 this->PopScopes(1);
1180 return expr;
1181 });
1182
1183 return relay::If(guard, true_branch, false_branch);
1184 });
1185 }
1186
1187 /* This factors parsing a list of patterns for both tuples, and constructors. */
1188 Array<Pattern> ParsePatternList() {
1189 return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1190 [&] { return ParsePattern(); });
1191 }
1192
1193 /*! \brief Parses a pattern for a match expression.
1194 *
1195 * A pattern is either a wildcard `_`, a local `%name`,
1196 * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn).
1197 *
1198 * This function recursively parses a pattern.
1199 */
1200 Pattern ParsePattern() {
1201 VLOG(9) << "Parser::ParsePattern";
1202 auto next = Peek();
1203 switch (next->token_type) {
1204 case TokenType::kUnderscore: {
1205 Match(TokenType::kUnderscore);
1206 return PatternWildcard();
1207 }
1208 case TokenType::kLocal: {
1209 auto id = Match(TokenType::kLocal);
1210 Type type_annotation;
1211 if (WhenMatch(TokenType::kColon)) {
1212 type_annotation = ParseType();
1213 }
1214 auto var = BindVar(id.ToString(), type_annotation);
1215 return PatternVar(var);
1216 }
1217 case TokenType::kIdentifier: {
1218 auto id = Match(TokenType::kIdentifier);
1219 auto ctor = ctors.Get(id.ToString());
1220 if (!ctor) {
1221 diag_ctx.EmitFatal(
1222 // TODO(@jroesch): split into error and help
1223 // deal with multiple rendering
1224 Diagnostic::Error(id->span)
1225 << "undefined constructor name `" << id.ToString()
1226 << "`, perhaps you intended to write a"
1227 << "pattern variable, considering changing this to `%" << id.ToString() << "`");
1228 }
1229 if (Peek()->token_type == TokenType::kOpenParen) {
1230 auto fields = ParsePatternList();
1231 return PatternConstructor(ctor.value(), fields);
1232 } else {
1233 return PatternConstructor(ctor.value(), {});
1234 }
1235 }
1236 default:
1237 return PatternTuple(ParsePatternList());
1238 }
1239 }
1240
1241 Clause ParseMatchArm() {
1242 PushScope();
1243 auto pattern = ParsePattern();
1244 Match(TokenType::kEqual);
1245 Consume(TokenType::kRAngle);
1246 auto expr = ParseExpr();
1247 PopScopes(1);
1248 return Clause(pattern, expr);
1249 }
1250
1251 Expr ParseMatch(bool is_total) {
1252 return WithSpan<Expr>([&]() {
1253 Expr scrutinee = ParseAtomicExpr();
1254
1255 Array<Clause> clauses =
1256 ParseSequence<Clause>(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly,
1257 [&] { return ParseMatchArm(); });
1258
1259 return relay::Match(scrutinee, clauses, is_total);
1260 });
1261 }
1262
1263 Expr ParseExprBinOp() {
1264 VLOG(9) << "Parser::ParseExprBinOp";
1265 return WithSpan<Expr>([this] {
1266 // We must parse at least one expression, the default
1267 // case is that there is no operator and we will fall
1268 // through.
1269 std::vector<Expr> exprs;
1270 Expr expr = WithSpan<Expr>([this] { return ParseCallExpr(); });
1271
1272 exprs.push_back(expr);
1273
1274 // Now we parse an optional op.
1275 std::vector<Rule> ops;
1276
1277 // We will now parse 0 or more operator occurrences.
1278 while (true) {
1279 auto opt_op = ParseOp();
1280
1281 // If we didn't parse one we done.
1282 if (opt_op.size() == 0) {
1283 break;
1284 }
1285
1286 // Read the operation we parsed;
1287 auto op = opt_op[0];
1288
1289 Expr right = WithSpan<Expr>([this] { return ParseCallExpr(); });
1290 ICHECK(right->span.defined());
1291
1292 // If the operator stack is empty
1293 // we parse an operator and expression
1294 // and push them to stacks, then
1295 // continue.
1296 if (ops.size() == 0) {
1297 ops.push_back(op);
1298 exprs.push_back(right);
1299 continue;
1300 }
1301
1302 if (op.precedence > ops.back().precedence ||
1303 (op.precedence == ops.back().precedence && op.left_assoc == false)) {
1304 ops.push_back(op);
1305 exprs.push_back(right);
1306 continue;
1307 }
1308
1309 while (ops.size() && (op.precedence < ops.back().precedence ||
1310 (op.precedence == ops.back().precedence && op.left_assoc == true))) {
1311 Rule new_op = ops.back();
1312 ops.pop_back();
1313 Expr right = exprs.back();
1314 exprs.pop_back();
1315 Expr left = exprs.back();
1316 exprs.pop_back();
1317 ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
1318 exprs.push_back(
1319 relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
1320 }
1321
1322 exprs.push_back(right);
1323 ops.push_back(op);
1324 }
1325
1326 while (ops.size()) {
1327 Rule new_op = ops.back();
1328 ops.pop_back();
1329 Expr right = exprs.back();
1330 exprs.pop_back();
1331 Expr left = exprs.back();
1332 exprs.pop_back();
1333 ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
1334 exprs.push_back(
1335 relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
1336 }
1337
1338 ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack.";
1339
1340 ICHECK_EQ(exprs.size(), 1)
1341 << "Only a single expression should be left on the expression stack.";
1342
1343 return exprs[0];
1344 });
1345 }
1346
1347 ObjectRef ParseAttributeValue() {
1348 VLOG(9) << "Parser::ParseAttributeValue";
1349 auto next = Peek();
1350 switch (next->token_type) {
1351 case TokenType::kFloat:
1352 case TokenType::kInteger:
1353 case TokenType::kBoolean:
1354 case TokenType::kStringLiteral:
1355 return Match(next->token_type)->data;
1356 case TokenType::kMetaReference:
1357 return ParseMetaRef();
1358 case TokenType::kLSquare: {
1359 return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
1360 [&]() { return ParseAttributeValue(); });
1361 }
1362 case TokenType::kOpenParen: {
1363 // TODO(@jroesch: need to figure out bracket vs. sequence)
1364 // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma,
1365 // TokenType::kCloseParen,
1366 // [&]() { return ParseAttributeValue(); });
1367 return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen,
1368 [&]() { return ParseAttributeValue(); });
1369 }
1370 // TODO(@jroesch): not sure about this being the right way to handle nulls.
1371 case TokenType::kIdentifier: {
1372 if (auto text = next->data.as<tvm::StringObj>()) {
1373 std::string id = GetRef<String>(text);
1374 if (id == "nullptr") {
1375 Match(TokenType::kIdentifier);
1376 return ObjectRef();
1377 }
1378 if (id == "None") {
1379 Match(TokenType::kIdentifier);
1380 return Optional<ObjectRef>();
1381 }
1382 }
1383 }
1384 default:
1385 return ParseAtomicExpr();
1386 }
1387 }
1388
1389 Map<String, ObjectRef> ParseAttrs() {
1390 VLOG(9) << "Parser::ParseAttrs";
1391 Map<String, ObjectRef> kwargs;
1392 while (Peek()->token_type == TokenType::kIdentifier) {
1393 auto key = GetHierarchicalName(ParseHierarchicalName().data);
1394 Match(TokenType::kEqual);
1395 // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side.
1396 auto value = ParseAttributeValue();
1397 // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text
1398 // format is bad.
1399 kwargs.Set(key, value);
1400 WhenMatch(TokenType::kComma);
1401 }
1402 VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs;
1403 return kwargs;
1404 }
1405
1406 Expr ParseCallArgs(Expr op) {
1407 ICHECK(op.defined()) << "the operator must be defined";
1408
1409 VLOG(9) << "Parser::ParseCallArgs";
1410 Attrs attrs;
1411 std::string op_key;
1412 bool is_op = false;
1413
1414 if (auto op_node = op.as<OpNode>()) {
1415 is_op = true;
1416 op_key = op_node->attrs_type_key;
1417 }
1418
1419 if (Peek()->token_type == TokenType::kOpenParen) {
1420 Array<Expr> args = ParseSequence<Expr>(
1421 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1422 [&] { return ParseExpr(); },
1423 [&] {
1424 auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
1425 auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
1426 auto is_pretty_attrs = is_ident && next_is_equal;
1427 auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
1428 // TODO(@jroesch): might not handle trailing comma
1429 auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
1430 auto is_meta_attrs = is_meta_next && last_meta;
1431
1432 if (is_pretty_attrs || is_meta_attrs) {
1433 if (is_meta_attrs) {
1434 auto meta_ref = ParseMetaRef();
1435 if (meta_ref.as<BaseAttrsNode>()) {
1436 attrs = Downcast<Attrs>(meta_ref);
1437 } else {
1438 // Not awesome parsing code here.
1439 this->pos--;
1440 return false;
1441 }
1442 } else {
1443 auto raw_attrs = ParseAttrs();
1444 if (is_op && op_key.size()) {
1445 auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
1446 ICHECK(attr_obj.defined());
1447 attrs = Downcast<Attrs>(attr_obj);
1448 } else if (raw_attrs.count("attrs_type_key")) {
1449 String attr_key = Downcast<String>(raw_attrs["attrs_type_key"]);
1450 if (attr_key.size()) {
1451 raw_attrs.erase("attrs_type_key");
1452 auto attr_obj =
1453 tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs);
1454 ICHECK(attr_obj.defined());
1455 attrs = Downcast<Attrs>(attr_obj);
1456 }
1457 } else {
1458 this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
1459 << "unable to determine the 'attrs_type_key' with which "
1460 "to represent the call attributes for this operator");
1461 }
1462 }
1463 return true;
1464 }
1465 return false;
1466 });
1467
1468 if (!attrs.defined()) {
1469 if (is_op && op_key.size()) {
1470 auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
1471 ICHECK(attr_obj.defined());
1472 attrs = Downcast<Attrs>(attr_obj);
1473 }
1474 }
1475
1476 // TODO(@jroesch): in a secondary pass adjust spans.
1477 return Expr(Call(op, args, attrs, {}));
1478 } else {
1479 return Expr();
1480 }
1481
1482 return Expr();
1483 }
1484
1485 Expr ParseCallExpr() {
1486 VLOG(9) << "Parser::ParseCallExpr";
1487 return WithSpan<Expr>([this] {
1488 Expr expr = ParseAtomicExpr();
1489 // Parse as many call args as possible, building up expression
1490 //
1491 // NB(@jroesch): this seems like a hack but in order to parse curried functions
1492 // and avoid complex grammar we will parse multiple call lists in a row.
1493 while (Peek()->token_type == TokenType::kOpenParen) {
1494 auto new_expr = ParseCallArgs(expr);
1495
1496 if (new_expr.defined()) {
1497 expr = new_expr;
1498 } else {
1499 break;
1500 }
1501 }
1502
1503 // We need a zero-arity case for constructors.
1504 if (auto ctor_node = expr.as<ConstructorNode>()) {
1505 if (ctor_node->inputs.size() == 0) {
1506 return Expr(Call(expr, {}));
1507 }
1508 }
1509
1510 return expr;
1511 });
1512 }
1513
1514 Expr GetOp(const std::string& op_name, const Span& span) {
1515 VLOG(9) << "op_name=" << op_name << " span=" << span;
1516 try {
1517 return Op::Get(op_name);
1518 } catch (const Error& e) {
1519 // we can relax this, but probably need to relax checks or return non-null here.
1520 this->diag_ctx.EmitFatal(Diagnostic::Error(span)
1521 << "operator `" << op_name
1522 << "` not found, perhaps you forgot to register it?");
1523 return Expr();
1524 }
1525 }
1526
1527 Expr ParseAtomicExpr() {
1528 VLOG(9) << "Parser::ParseAtomicExpr";
1529 Expr expr = WithSpan<Expr>([this] {
1530 auto next = Peek();
1531 switch (next->token_type) {
1532 case TokenType::kInteger:
1533 case TokenType::kFloat: {
1534 Consume(next->token_type);
1535 auto number = NumberToNDArray(next);
1536 Expr e = Constant(number, next->span);
1537 ICHECK(e->span.defined()) << "constant spans must be defined";
1538 return e;
1539 }
1540 case TokenType::kBoolean: {
1541 Consume(TokenType::kBoolean);
1542 int64_t value = Downcast<tvm::Integer>(next->data).IntValue();
1543 Expr e = Constant(support::BoolToNDArray(value), next->span);
1544 ICHECK(e->span.defined()) << "constant spans must be defined";
1545 return e;
1546 }
1547 // Parse a local of the form `%x`.
1548 case TokenType::kLocal: {
1549 Consume(TokenType::kLocal);
1550 return Expr(LookupLocal(next));
1551 }
1552 // Parse a local of the form `@x`.
1553 case TokenType::kGlobal: {
1554 auto global_name = next.ToString();
1555 Consume(TokenType::kGlobal);
1556 auto global = AddOrGet(&global_names, global_name);
1557 return Expr(global);
1558 }
1559 // Parse a local of the form `x`.
1560 // Right now we fail to parse `x.y`.
1561 case TokenType::kIdentifier: {
1562 auto ctor = ctors.Get(next.ToString());
1563 if (ctor) {
1564 Consume(TokenType::kIdentifier);
1565 return Expr(ctor.value());
1566 } else {
1567 auto spanned_idents = ParseHierarchicalName();
1568 auto idents = spanned_idents.data;
1569 auto span = spanned_idents.span;
1570 return GetOp(GetHierarchicalName(idents), span);
1571 }
1572 }
1573 case TokenType::kGraph: {
1574 Consume(TokenType::kGraph);
1575 return LookupGraphBinding(next);
1576 }
1577 case TokenType::kMetaReference: {
1578 return Downcast<Expr>(ParseMetaRef());
1579 }
1580 case TokenType::kFn: {
1581 Consume(TokenType::kFn);
1582 Expr e = ParseFunctionDef();
1583 ICHECK(e->span.defined()) << "function spans must be defined.\n" << e;
1584 return e;
1585 }
1586 case TokenType::kIf: {
1587 Expr e = ParseIf();
1588 return e;
1589 }
1590 case TokenType::kRef: {
1591 Consume(TokenType::kRef);
1592 Match(TokenType::kOpenParen);
1593 auto ref_value = ParseExpr();
1594 Match(TokenType::kCloseParen);
1595 return static_cast<Expr>(RefCreate(ref_value));
1596 }
1597 case TokenType::kRefRead: {
1598 return WithSpan<Expr>([&]() {
1599 Consume(TokenType::kRefRead);
1600 Match(TokenType::kOpenParen);
1601 auto ref = ParseExpr();
1602 Match(TokenType::kCloseParen);
1603 return static_cast<Expr>(RefRead(ref));
1604 });
1605 }
1606 case TokenType::kRefWrite: {
1607 return WithSpan<Expr>([&]() {
1608 Consume(TokenType::kRefWrite);
1609 Match(TokenType::kOpenParen);
1610 auto ref = ParseExpr();
1611 Match(TokenType::kComma);
1612 auto value = ParseExpr();
1613 Match(TokenType::kCloseParen);
1614 return static_cast<Expr>(RefWrite(ref, value));
1615 });
1616 }
1617 case TokenType::kOpenParen: {
1618 Span sp = next->span;
1619 Consume(TokenType::kOpenParen);
1620 // parse '(' ')'
1621 if (WhenMatch(TokenType::kCloseParen)) {
1622 return Expr(Tuple(Array<Expr>()));
1623 } else {
1624 Expr subexpr = ParseExpr();
1625 // parse '(' expr ')'
1626 if (WhenMatch(TokenType::kCloseParen)) {
1627 return subexpr;
1628 // parse '( expr ',' * ')'
1629 } else if (WhenMatch(TokenType::kComma)) {
1630 Array<Expr> exprs = {subexpr};
1631 while (true) {
1632 if (WhenMatch(TokenType::kCloseParen)) {
1633 break;
1634 } else {
1635 auto element = ParseExpr();
1636 auto comma = Peek();
1637 if (WhenMatch(TokenType::kComma)) {
1638 sp = sp.Merge(element->span.Merge(comma->span));
1639 } else {
1640 sp = sp.Merge(element->span);
1641 }
1642 exprs.push_back(element);
1643 }
1644 }
1645 Expr tuple = Tuple(exprs, sp);
1646 ICHECK(tuple->span.defined()) << "tuple span should be defined";
1647 return tuple;
1648 }
1649 }
1650 }
1651 default: {
1652 this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
1653 << "expected an expression found " << Pretty(next->token_type));
1654 return Expr();
1655 }
1656 }
1657 });
1658
1659 if (WhenMatch(TokenType::kPeriod)) {
1660 auto token = Match(TokenType::kInteger);
1661 auto index = token.ToNumber();
1662 auto span = token->span.Merge(expr->span);
1663 VLOG(9) << "Parser::ParseAtomicExpr: tuple get item";
1664 return relay::TupleGetItem(expr, index, span);
1665 } else {
1666 return expr;
1667 }
1668 }
1669
1670 /*! \brief Parse a hierarchical name.
1671 *
1672 * The tokenizer produces a token stream of <id1> . <id2>
1673 * and so on for names of the form `nn.conv2d`.
1674 * Currently we only use string names everywhere instead
1675 * of a notion of a hierarchical name.
1676 *
1677 * The below utility reassembles a token stream into a
1678 * single stream inserting the required periods needed
1679 * to look up registered names.
1680 */
1681 Spanned<Array<String>> ParseHierarchicalName() {
1682 Array<String> idents;
1683 Span span;
1684 while (Peek()->token_type == TokenType::kIdentifier) {
1685 auto token = Peek();
1686
1687 if (span.defined()) {
1688 span = span.Merge(token->span);
1689 } else {
1690 span = token->span;
1691 }
1692
1693 auto name = token.ToString();
1694 idents.push_back(name);
1695 Consume(TokenType::kIdentifier);
1696
1697 // Keep parsing while we see a trailing period.
1698 if (Peek()->token_type == TokenType::kPeriod) {
1699 Consume(TokenType::kPeriod);
1700 continue;
1701 } else {
1702 // No more periods means we are done!
1703 break;
1704 }
1705 }
1706
1707 return Spanned<Array<String>>(idents, span);
1708 }
1709
1710 std::string GetHierarchicalName(Array<String> idents) {
1711 ICHECK_NE(idents.size(), 0);
1712 std::stringstream hierarchical_name;
1713 int i = 0;
1714 int periods = idents.size() - 1;
1715 for (auto ident : idents) {
1716 hierarchical_name << ident;
1717 if (i < periods) {
1718 hierarchical_name << ".";
1719 i++;
1720 }
1721 }
1722 return hierarchical_name.str();
1723 }
1724
1725 /*! \brief Parse a shape. */
1726 Array<tvm::PrimExpr> ParseShape() {
1727 auto dims = ParseSequence<tvm::PrimExpr>(
1728 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
1729 tvm::PrimExpr dim;
1730 if (Peek()->token_type == TokenType::kMetaReference) {
1731 dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
1732 } else if (WhenMatch(TokenType::kQuestion)) {
1733 dim = tvm::tir::Any();
1734 } else {
1735 dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
1736 }
1737
1738 return dim;
1739 });
1740 return dims;
1741 }
1742
1743 /*! \brief Parse a function type. */
1744 Type ParseFunctionType() {
1745 auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
1746 TokenType::kCloseParen, [&]() { return ParseType(); });
1747
1748 Match(TokenType::kMinus);
1749 Match(TokenType::kRAngle);
1750 auto ret_type = ParseType();
1751
1752 return relay::FuncType(ty_params, ret_type, {}, {});
1753 }
1754
1755 // Parses a user defined ADT or type variable.
1756 Type ParseNonPrimitiveType(const Token& tok) {
1757 return WithSpan<Type>([&]() {
1758 auto name = tok.ToString();
1759 Type head_type = LookupTypeVar(tok);
1760
1761 if (!head_type.defined()) {
1762 // head_type = type_names.Get(name);
1763 head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle);
1764 }
1765
1766 if (!head_type.defined()) {
1767 diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
1768 << "the type constructor `" << name << "` is undefined");
1769 }
1770
1771 Array<Type> arg_types;
1772 if (Peek()->token_type == TokenType::kLSquare) {
1773 arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
1774 [&]() { return ParseType(); });
1775 }
1776
1777 if (arg_types.size()) {
1778 return static_cast<Type>(TypeCall(head_type, arg_types));
1779 } else {
1780 if (head_type.as<GlobalTypeVarNode>()) {
1781 return static_cast<Type>(TypeCall(head_type, {}));
1782 } else {
1783 return static_cast<Type>(head_type);
1784 }
1785 }
1786 });
1787 }
1788
  /*! \brief Parses a TVM type.
   *
   * This matches a `Tensor[shape, dtype]`, a function type, a tuple type,
   * a user-defined ADT or type variable, a scalar type, or the incomplete type `_`.
   */
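  // Some concrete textual forms this accepts (illustrative):
  //   Tensor[(1, 224, 224, 3), float32]    // tensor type with an explicit shape
  //   fn (int32, int32) -> int32           // function type
  //   (float32, float32)                   // tuple type
  //   int64                                // scalar type, i.e. a rank-0 tensor
  //   _                                    // incomplete type, left to inference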
1794 Type ParseType() {
1795 return WithSpan<Type>([&]() -> Type {
1796 auto tok = Peek();
1797
1798 if (tok->token_type == TokenType::kOpenParen) {
1799 auto tys =
1800 ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma,
1801 TokenType::kCloseParen, [&]() { return ParseType(); });
1802 return relay::TupleType(tys);
1803 } else if (WhenMatch(TokenType::kFn)) {
1804 return ParseFunctionType();
1805 } else if (WhenMatch(TokenType::kIdentifier)) {
1806 auto id = tok.ToString();
1807 if (id == "Tensor") {
1808 Match(TokenType::kLSquare);
1809 auto shape = ParseShape();
1810 Match(TokenType::kComma);
1811 auto dtype_tok = Match(TokenType::kIdentifier);
1812 auto dtype = DataType(String2DLDataType(dtype_tok.ToString()));
1813 Match(TokenType::kRSquare);
1814 return TensorType(shape, dtype);
1815 } else {
1816 auto ty = tok.ToString();
        if (ty.rfind("int", 0) == 0 || ty.rfind("float", 0) == 0 || ty.rfind("uint", 0) == 0 ||
            ty.rfind("bool", 0) == 0) {
1819 // Need to do better error handling here.
1820 auto dtype = DataType(String2DLDataType(tok.ToString()));
1821 return TensorType({}, dtype);
1822 } else {
1823 return ParseNonPrimitiveType(tok);
1824 }
1825 }
1826 } else if (WhenMatch(TokenType::kUnderscore)) {
1827 return IncompleteType();
1828 } else {
1829 this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
                                 << "failed to parse type; found " << tok);
1831 return Type();
1832 }
1833 });
1834 }
1835
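  /*! \brief Run the given thunk with whitespace skipping enabled, first consuming
   * any whitespace tokens at the current position, and restore the previous
   * setting afterwards.
   */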
1836 template <typename R>
1837 R ConsumeWhitespace(std::function<R()> func) {
1838 auto old = this->ignore_whitespace;
1839 this->ignore_whitespace = true;
1840 while (tokens[pos]->token_type == TokenType::kWhitespace) {
1841 pos++;
1842 }
1843 auto res = func();
1844 this->ignore_whitespace = old;
1845 return res;
1846 }
1847
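  /*! \brief Parse the #[metadata] section if it is the next token; otherwise
   * return an empty table.
   */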
1848 Map<String, Array<ObjectRef>> ParseMetadata() {
1849 if (Peek()->token_type == TokenType::kMetadata) {
1850 return Match(TokenType::kMetadata).ToMetadata();
1851 } else {
1852 return Map<String, Array<ObjectRef>>();
1853 }
1854 }
1855
1856 /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */
1857 void DisplayNextN(int n) {
1858 std::cout << "remaining tokens: " << std::endl;
1859 auto bound = std::min(pos + n, static_cast<int>(tokens.size()));
1860 for (int i = 0; i < bound - pos; i++) {
1861 std::cout << tokens[pos + i] << std::endl;
1862 }
1863 }
1864
  /*! \brief A helper for debugging the operator parser; displays the expression and rule stacks. */
1866 void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) {
1867 std::cout << "Expr Stack: ";
1868 for (auto expr : exprs) {
1869 std::cout << expr << ", ";
1870 }
1871
1872 std::cout << std::endl;
1873 std::cout << "Op Stack: ";
1874 for (auto rule : rules) {
1875 std::cout << rule.op << ", ";
1876 }
1877
1878 std::cout << std::endl;
1879 }
1880};
1881
1882Parser InitParser(const std::string& file_name, const std::string& file_content,
1883 const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
  VLOG(9) << "InitParser: file_name: " << file_name
          << " file_content_size: " << file_content.size();
1885 SourceName src_name = SourceName::Get(file_name);
1886 Source source(src_name, file_content);
1887
1888 IRModule module;
1889 if (!init_module) {
1890 SourceMap source_map;
1891 module = IRModule({}, {}, {}, source_map);
1892 } else {
1893 module = init_module.value();
1894 }
1895
1896 module->source_map.Add(source);
1897
1898 auto diag_ctx = DiagnosticContext::Default(module);
1899 auto tokens_and_table = Tokenize(diag_ctx, source);
1900
1901 auto tokens = tokens_and_table.first;
1902 MetaTable meta_data_table = tokens_and_table.second.ToMetadata();
1903
  // Merge any entries in init_meta_table into whatever was captured in the #[metadata] section
  // of file_content. Entries from the file come first and injected entries are appended, so
  // metadata references within file_content must use indices that account for this ordering.
1907 for (const auto& pair : init_meta_table) {
1908 Array<ObjectRef> items;
1909 if (meta_data_table.count(pair.first)) {
1910 items = meta_data_table[pair.first];
1911 }
1912 for (const auto& obj : pair.second) {
1913 items.push_back(obj);
1914 }
1915 meta_data_table.Set(pair.first, items);
1916 }
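  // For example (a sketch with illustrative counts): if the file's metadata already holds two
  // relay.Constant entries, constants injected via init_meta_table are appended after them and
  // would be referenced from the source as meta[relay.Constant][2], meta[relay.Constant][3], ...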
1917
1918 return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
1919}
1920
1921IRModule ParseModule(const std::string& file_name, const std::string& file_content,
1922 const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
1923 VLOG_CONTEXT << "ParseModule";
1924 VLOG(9) << "parsing and type-checking " << file_name;
1925 auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
1926 auto mod = parser.ParseModule();
1927 ICHECK(mod.defined()) << "The parser must return a non-null module.";
  // NB(@jroesch): it is very important that we render any errors before proceeding;
  // if the parser recovered from errors and continued, they must be surfaced here.
1931 parser.diag_ctx.Render();
1932 auto infer_type = tvm::relay::transform::InferType();
1933 ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
1934 return infer_type(mod);
1935}
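// A usage sketch (the file name and program text are illustrative):
//   IRModule mod = ParseModule(
//       "example.relay", "#[version = \"0.0.5\"]\ndef @main(%x: float32) -> float32 { %x }");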
1936
1937Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
1938 VLOG(9) << "ParseExpr";
1939 auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
1940 parser.ParseSemVer(false);
1941 parser.PushScope();
1942 auto expr = parser.ParseExpr();
1943 parser.Match(TokenType::kEndOfFile);
  // NB(@jroesch): it is very important that we render any errors before proceeding;
  // if the parser recovered from errors and continued, they must be surfaced here.
1947 parser.diag_ctx.Render();
1948 return expr;
1949}
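// A usage sketch (the source name and expression text are illustrative):
//   Expr e = ParseExpr("string", "let %x = 1; %x");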
1950
1951TVM_REGISTER_GLOBAL("parser.ParseModuleInContext")
1952 .set_body_typed([](const std::string& file_name, const std::string& file_content,
1953 const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
1954 return ParseModule(file_name, file_content, init_module, init_meta_table);
1955 });
1956
1957TVM_REGISTER_GLOBAL("parser.ParseModule")
1958 .set_body_typed([](const std::string& file_name, const std::string& file_content) {
1959 return ParseModule(file_name, file_content);
1960 });
1961
1962TVM_REGISTER_GLOBAL("parser.ParseExpr")
1963 .set_body_typed([](tvm::String file_name, tvm::String file_content) {
1964 return ParseExpr(file_name, file_content);
1965 });
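// These registrations are typically reached from Python via tvm.parser.parse(...) and
// tvm.parser.parse_expr(...) (a sketch of the intended entry points, not an exhaustive list).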
1966
/*!
 * \brief This pass pretty-prints mod and then parses it back so as to establish spans and sources
 * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
 * modules constructed programmatically rather than textually.
 */
1972Pass AnnotateSpans() {
1973 auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
1974 String text = AsText(mod, /*show_meta_data=*/true);
1975 VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
1976 return ParseModule("GeneratedSource", text);
1977 };
1978 return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
1979}
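// A usage sketch (from C++; the pass is also exposed to Python via the registration below):
//   IRModule with_spans = AnnotateSpans()(mod);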
1980
1981TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);
1982
1983} // namespace parser
1984} // namespace tvm
1985