1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file parser.cc
22 * \brief A parser for TVM IR.
23 */
24#include <tvm/ir/module.h>
25#include <tvm/node/reflection.h>
26#include <tvm/relay/adt.h>
27#include <tvm/relay/expr.h>
28#include <tvm/relay/function.h>
29#include <tvm/relay/parser.h>
30#include <tvm/relay/transform.h>
31#include <tvm/runtime/builtin_fp16.h>
32#include <tvm/runtime/logging.h>
33#include <tvm/runtime/object.h>
34#include <tvm/runtime/registry.h>
35#include <tvm/target/virtual_device.h>
36
37#include <fstream>
38
39#include "../../support/scalars.h"
40#include "./meta_ref.h"
41#include "./op_table.h"
42#include "./span_check.h"
43#include "./tokenizer.h"
44
45namespace tvm {
46namespace relay {
47
48/*! \brief The meta table maps from type key to a sequence of objects. */
49using MetaTable = Map<String, Array<ObjectRef>>;
50
51using tvm::transform::CreateModulePass;
52using tvm::transform::PassContext;
53
/*! \brief A helper for passing around spans with data structures that
 * have no span field.
 */
57template <typename T>
58struct Spanned {
59 T data;
60 Span span;
61
62 Spanned() = default;
63 Spanned(const Spanned<T>& other) = default;
64 Spanned(T data, Span span) : data(data), span(span) {}
65};
66
67/*! \brief A wrapper structure for capturing the result of parsing
68 * a global definition *before* we add it to the IRModule.
69 *
70 * This enables the parser to parse everything in one pass before
71 * constructing the IRModule.
72 */
73struct GlobalFunc {
74 GlobalVar global;
75 Function function;
76 GlobalFunc() : global(), function() {}
77 GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {}
78 GlobalFunc(const GlobalFunc& gfunc) {
79 this->global = gfunc.global;
80 this->function = gfunc.function;
81 }
82};
83
84/*! \brief A wrapper structure for capturing all top-level definitions
85 * when parsing a module.
86 */
87struct Definitions {
88 /*! \brief The set of global functions. */
89 std::vector<GlobalFunc> funcs;
90 /*! \brief The set of type definitions. */
91 std::vector<TypeData> types;
92 // TODO(@jroesch): contain meta-table below
93};
94
95/*! \brief A structure representing the semantic versioning information
96 * for a Relay program.
97 */
98class SemVer {
99 public:
100 int major_version;
101 int minor_version;
102 int patch_version;
103
104 SemVer() : major_version(0), minor_version(0), patch_version(0) {}
105 SemVer(int major_version, int minor_version, int patch_version)
106 : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {}
107 SemVer(const SemVer& other)
108 : major_version(other.major_version),
109 minor_version(other.minor_version),
110 patch_version(other.patch_version) {}
111};
112
113/*! \brief A simple wrapper around a mapping from raw string names
114 * to a TVM variable, type variable or other binder type.
115 */
116template <typename T>
117struct Scope {
118 /*! \brief The internal map. */
119 std::unordered_map<std::string, T> name_map;
120};
121
122/*! \brief A stack of scopes.
123 *
124 * In order to properly handle scoping we must maintain a stack of scopes.
125 *
126 * A stack allows users to write programs which contain repeated variable
127 * names and to properly handle both nested scopes and removal of variables
128 * when they go out of scope.
129 *
130 * This is the classic approach to lexical scoping.
131 */
132template <typename T>
133class ScopeStack {
134 private:
135 std::vector<Scope<T>> scope_stack;
136 std::unordered_map<std::string, T> free_vars;
137
138 public:
139 /*! \brief Adds a variable binding to the current scope. */
140 void Add(const std::string& name, const T& value) {
    if (!this->scope_stack.size()) {
      LOG(FATAL) << "internal error: no active scope to add the binding to";
    }
144 this->scope_stack.back().name_map.insert({name, value});
145 }
146
147 void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); }
148
  /*! \brief Looks up a variable name in the scope stack, returning the matching variable
   * in the most recent scope. */
151 T Lookup(const std::string& name) {
152 for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) {
153 auto it = scope->name_map.find(name);
154 if (it != scope->name_map.end()) {
155 return it->second;
156 }
157 }
158
159 // Check if we bound a free variable declaration.
160 auto it = free_vars.find(name);
161 if (it != free_vars.end()) {
162 return it->second;
163 }
164
165 return T();
166 }
167
168 /*! \brief Adds a fresh scope. */
169 void PushStack() { this->scope_stack.push_back(Scope<T>()); }
170
171 /*! \brief Removes the most recent scope. */
172 void PopStack() { this->scope_stack.pop_back(); }
173};
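
// Illustrative sketch (not part of the original source) of how the scope stack behaves.
// The innermost scope shadows outer bindings, and popping a scope restores them:
//
//   ScopeStack<int> scopes;
//   scopes.PushStack();
//   scopes.Add("x", 1);
//   scopes.PushStack();
//   scopes.Add("x", 2);
//   scopes.Lookup("x");  // returns 2: the innermost binding wins.
//   scopes.PopStack();
//   scopes.Lookup("x");  // returns 1 again once the inner scope is popped.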
174
175struct DuplicateKeyError : public Error {
176 explicit DuplicateKeyError(const std::string& msg) : Error(msg) {}
177};
178
/*! \brief A table for interning strings as global function and type names. */
180template <typename T>
181struct InternTable {
182 /*! \brief The internal table mapping strings to a unique allocation. */
183 std::unordered_map<std::string, T> table;
184 DiagnosticContext* ctx;
185
186 /*! \brief Add the unique allocation. */
187 void Add(const std::string& name, const T& t) {
188 auto it = table.find(name);
189 if (it != table.end()) {
190 throw DuplicateKeyError("duplicate key name in intern table");
191 } else {
192 table.insert({name, t});
193 }
194 }
195
196 /*! \brief Return the unique allocation. */
197 Optional<T> Get(const std::string& name) const {
198 auto it = table.find(name);
199 if (it != table.end()) {
200 return Optional<T>(it->second);
201 } else {
202 return Optional<T>();
203 }
204 }
205};
206
207GlobalVar AddOrGet(InternTable<GlobalVar>* table, const std::string& name) {
208 auto var = table->Get(name);
209 if (var) {
210 return var.value();
211 } else {
212 auto gvar = GlobalVar(name);
213 table->Add(name, gvar);
214 return gvar;
215 }
216}
217
218GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& name,
219 TypeKind kind = TypeKind::kType) {
220 auto var = table->Get(name);
221 if (var) {
222 auto tvar = var.value();
223 TypeKind& tvar_kind = const_cast<TypeKind&>(tvar->kind);
224 tvar_kind = kind;
225 return tvar;
226 } else {
227 auto gvar = GlobalTypeVar(name, kind);
228 table->Add(name, gvar);
229 return gvar;
230 }
231}
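
// Illustrative sketch (not part of the original source): AddOrGet interns names, so
// repeated references to the same global resolve to a single GlobalVar node:
//
//   InternTable<GlobalVar> globals;
//   GlobalVar a = AddOrGet(&globals, "main");
//   GlobalVar b = AddOrGet(&globals, "main");
//   // `a` and `b` are the same underlying GlobalVar, so every reference to @main
//   // and its definition agree on one node.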
232
/*! \brief The parser class is the main interface to the parser.
 * The parser is not currently exposed beyond this .cc file.
235 *
236 * The parser is initialized with a diagnostic context, an
237 * operator table, and a token stream.
238 *
239 * The rest of the internal state is used to map the human readable
240 * form to in-memory IR representation.
241 *
 * The main entry points to the parser are a set of parsing methods
 * such as `ParseModule` and `ParseExpr`.
244 *
245 * As with traditional recursive descent parsers the parsing methods
246 * are factored recursively just as one would do with a formal language
247 * grammar.
248 *
249 * You can view a recursive descent parser as a human friendly way to specify
250 * a state machine, and thus this factoring is necessary as the 'state' of this
251 * machine is the combination of the current parsing method and the next token.
252 *
253 * Parsing proceeds by matching a token and then dispatching to the appropriate
254 * method to parse the next tokens in the stream.
255 *
256 * For example if we are parsing a type and encounter a "Tensor" token we switch
257 * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`.
258 *
 * Certain matches like this are unambiguous and proceed in a straight-line fashion
 * once the initial token is found. Other parsing is more complex and requires some
 * tricks to parse correctly.
262 *
 * For example when we find a '(' in an expression context, it may be part of
 * a tuple, the arguments to a call, or a parenthesized expression. The code below
 * disambiguates these cases by factoring expression parsing into a series of methods
 * which encode the parsing context and thus how to interpret the parenthesis.
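 *
 * For example `(%x)` is just the expression bound to `%x`, `(%x,)` is a one-element
 * tuple, and `%f(%x)` is a call, so the meaning of '(' depends on what was parsed
 * immediately before it and on the tokens that follow.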
267 *
268 * For more information one should be able to read the code in order starting with
269 * `ParseModule` or `ParseExpr`.
270 */
271class Parser {
272 public:
273 /*! \brief The version that the parser is parsing. */
274 SemVer version;
275
276 /*! \brief The IRModule we are building. */
277 IRModule module;
278
279 /*! \brief The diagnostic context used for error reporting. */
280 DiagnosticContext diag_ctx;
281
282 const Source& source;
283
284 /*! \brief The current position in the token stream. */
285 int pos;
286
287 /*! \brief The token stream for the parser. */
288 std::vector<Token> tokens;
289
290 /*! \brief The configured operator table. */
291 OperatorTable op_table;
292
  /*! \brief Configures the whitespace mode; right now we ignore all whitespace. */
294 bool ignore_whitespace;
295
296 /*! \brief A global mapping for GlobalVar. */
297 InternTable<GlobalVar> global_names;
298
299 /*! \brief A global mapping for type definitions. */
300 InternTable<GlobalTypeVar> type_names;
301
302 /*! \brief A global mapping for constructor names. */
303 InternTable<Constructor> ctors;
304
305 /*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */
306 std::unordered_map<int, Expr> graph_ctx;
307
308 /*! \brief The set of type scopes used for generics. */
309 ScopeStack<TypeVar> type_scopes;
310
311 /*! \brief The set of expression scopes used for lexical scope. */
312 ScopeStack<Var> expr_scopes;
313
314 /*! \brief The metadata section. */
315 MetaTable meta_table;
316
317 Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector<Token> tokens,
318 OperatorTable op_table, MetaTable table)
319 : module(module),
320 diag_ctx(ctx),
321 source(source),
322 pos(0),
323 tokens(tokens),
324 op_table(op_table),
325 ignore_whitespace(true),
326 meta_table(table) {
327 InitializeGlobals();
328 InitializeTypeDefs();
329 }
330
  /*! \brief If we are parsing into a module with previously loaded data types, we need to
   * map constructor names and variable names in the global tables.
   */
334 void InitializeTypeDefs() {
335 for (auto pair : this->module->type_definitions) {
336 type_names.Add(pair.first->name_hint, pair.first);
337 for (auto ctor : pair.second->constructors) {
338 ctors.Add(ctor->name_hint, ctor);
339 }
340 }
341 }
342
343 void InitializeGlobals() {
344 for (auto pair : this->module->functions) {
345 global_names.Add(pair.first->name_hint, pair.first);
346 }
347 }
348
  /*! \brief Examine the next token in the stream. The parser is currently configured to be
   * whitespace-insensitive, so we skip all whitespace and comment tokens. */
351 Token Peek() {
352 // For now we ignore all whitespace tokens and comments.
353 // We can tweak this behavior later to enable white space sensitivity in the parser.
354 while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace &&
355 (tokens.at(pos)->token_type == TokenType::kWhitespace ||
356 tokens.at(pos)->token_type == TokenType::kNewline ||
357 tokens.at(pos)->token_type == TokenType::kLineComment ||
358 tokens.at(pos)->token_type == TokenType::kComment)) {
359 pos++;
360 }
361
362 if (pos < static_cast<int64_t>(tokens.size())) {
363 return Token(this->tokens.at(pos));
364 } else {
365 return Token::Null();
366 }
367 }
368
369 /*! \brief Lookahead by N tokens.
370 * \param n The number of tokens to lookahead.
371 * \return The Nth token.
372 */
373 Token Lookahead(int n) {
374 ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1";
375
376 // We intend to skip n - 1 tokens, then return the nth.
377 auto old_pos = pos;
378 for (int i = 0; i < n - 1; i++) {
379 Peek();
380 pos++;
381 }
382
383 auto tok = Peek();
384 pos = old_pos;
385 return tok;
386 }
387
  /*! \brief Consume a token. This method is the lowest-level way to consume a token
   * and will not ignore whitespace or look ahead in any way.
   *
   * \param token_type The token type to match.
392 */
393 void Consume(const TokenType& token_type) {
394 if (tokens[pos]->token_type != token_type) {
395 this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span)
396 << "expected a " << Pretty(token_type) << " found "
397 << Pretty(Peek()->token_type));
398 }
399 pos++;
400 }
401
  /*! Match a token in the stream. This first invokes Peek, skipping tokens such
   * as whitespace or comments, and returns the first meaningful token.
   *
   * We then try to consume the requested token; this triggers an error if the
   * current token does not match the token_type.
407 */
408 Token Match(const TokenType& token_type) {
409 auto tok = Peek();
410 Consume(token_type);
411 return tok;
412 }
413
  /*! Conditionally consume a token when it matches; this never triggers an error
   * because we check the token before consuming it.
   *
   * Useful for matching optional tokens; effectively a lookahead of one.
418 */
419 bool WhenMatch(const TokenType& token_type) {
420 VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek();
421 if (Peek()->token_type == token_type) {
422 Consume(token_type);
423 return true;
424 } else {
425 return false;
426 }
427 }
428
  /*! \brief Add a graph binding to the parsing context.
   *
   * For example, if we parse `%0 = add(...)`, we map 0 -> add(...), and so on.
432 */
433 void AddGraphBinding(const Token& token, const Expr& expr) {
434 auto graph_no = token.ToNumber();
435 this->graph_ctx.insert({graph_no, expr});
436 }
437
  /*! \brief Lookup a previously bound graph variable.
   *
   * Note: we take tokens in all lookup methods so that we
   * can do error reporting based on token location.
442 */
443 Expr LookupGraphBinding(const Token& token) {
444 auto graph_no = token.ToNumber();
445 auto it = this->graph_ctx.find(graph_no);
446 if (it != this->graph_ctx.end()) {
447 return it->second;
448 } else {
449 LOG(FATAL) << "Local variable %" << graph_no << " has not yet been defined";
450 throw;
451 }
452 }
453
454 /*! \brief Bind a local variable in the expression scope.
455 *
456 * "x" -> Var("x"), these are needed to map from the raw string names
457 * to unique variable nodes.
458 * If a virtual device is specified, sets the virtual device of the variable.
459 */
460 Var BindVar(const std::string& name, const relay::Type& type_annotation,
461 Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) {
462 auto var = Var(name, type_annotation);
463 var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained());
464 VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var);
465 this->expr_scopes.Add(name, var);
466 return var;
467 }
468
469 /*! \brief Bind a local variable in the expression scope.
470 *
471 * "x" -> Var("x"), these are needed to map from the raw string names
472 * to unique variable nodes.
473 */
474 Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) {
475 auto var = Var(name, type_annotation);
476 this->expr_scopes.AddFreeVar(name, var);
477 return var;
478 }
479
480 /*! \brief Bind a type variable in the type scope.
481 *
482 * "A" -> TypeVar("A", ...), these are needed to map from raw string names
483 * to unique type variable nodes.
484 */
485 TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) {
486 auto type_var = TypeVar(name, type_kind);
487 this->type_scopes.Add(name, type_var);
488 return type_var;
489 }
490
491 /*! \brief Lookup a variable in the expression scope.
492 *
493 * Note: all lookup methods take tokens intentionally for error reporting information.
494 */
495 Var LookupLocal(const Token& local) {
496 auto var = this->expr_scopes.Lookup(local.ToString());
497 if (!var.defined()) {
498 diag_ctx.Emit(Diagnostic::Error(local->span)
499 << "this local variable has not been previously declared");
500 }
501 return var;
502 }
503
504 /*! \brief Lookup a variable in the type scope.
505 *
506 * Note: all lookup methods take tokens intentionally for error reporting information.
507 */
508 TypeVar LookupTypeVar(const Token& ident) {
509 auto var = this->type_scopes.Lookup(ident.ToString());
510 return var;
511 }
512
513 /*! \brief Add an expression scope to the scope stack. */
514 void PushScope() { this->expr_scopes.PushStack(); }
515
516 /*! \brief Remove N expression scopes from the scope stack. */
517 void PopScopes(int n) {
518 for (int i = 0; i < n; i++) {
519 this->expr_scopes.PopStack();
520 }
521 }
522
  /*! \brief Add a type scope to the scope stack. */
524 void PushTypeScope() { this->type_scopes.PushStack(); }
525
526 /*! \brief Remove N type scopes from the scope stack. */
527 void PopTypeScopes(int n) {
528 for (int i = 0; i < n; i++) {
529 this->type_scopes.PopStack();
530 }
531 }
532
533 /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
534 NDArray NumberToNDArray(const Token& token) {
535 if (token->token_type == TokenType::kInteger) {
536 return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data));
537 } else if (token->token_type == TokenType::kFloat) {
538 return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data));
539 } else {
540 LOG(FATAL) << "internal error: should only call this function on numeric tokens";
541 }
542 }
543
544 [[noreturn]] void ParseError(const Token& token, const std::string& msg) {
545 throw std::runtime_error(msg);
546 }
547
548 /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */
549 template <typename R>
550 R Bracket(TokenType open, TokenType close, std::function<R()> parser) {
551 Match(open);
552 R result = parser();
553 Match(close);
554 return result;
555 }
556
557 /*! \brief Parse `(` parser() `)`. */
558 template <typename R>
559 R Parens(std::function<R()> parser) {
560 return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser);
561 }
562
563 /*! \brief Parse `{` parser() `}`. */
564 template <typename R>
565 R Block(std::function<R()> parser) {
566 return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser);
567 }
568
569 template <typename R>
570 R WithSpan(std::function<R()> parser) {
571 auto start_span = Peek()->span;
572 VLOG(9) << "WithSpan: start_span = " << start_span;
573 R ast = parser();
574 if (ast.defined()) {
      // The token at the head of the stream is now one past the region we just parsed.
      // Walk backwards over any whitespace and comment tokens to find the last meaningful
      // token, and merge the start span with its span so that the resulting span covers
      // exactly the region we parsed.
578 auto span_pos = pos - 1;
579 while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace ||
580 tokens.at(span_pos)->token_type == TokenType::kNewline ||
581 tokens.at(span_pos)->token_type == TokenType::kLineComment ||
582 tokens.at(span_pos)->token_type == TokenType::kComment)) {
583 span_pos--;
584 }
585 auto end_token = tokens.at(span_pos);
586 VLOG(9) << "WithSpan: end_span = " << end_token->span;
587 ast->span = start_span.Merge(end_token->span);
588 }
589 return ast;
590 }
591
592 struct MetaRef {
593 std::string type_key;
594 uint64_t node_index;
595 Span span;
596 MetaRef(std::string type_key, uint64_t node_index, Span span)
597 : type_key(type_key), node_index(node_index), span(span) {}
598 };
599
600 MetaRef MetaRefFromToken(const Token& tok) {
601 Call ref = Downcast<Call>(tok->data);
602 auto attrs = ref->attrs.as<MetaRefAttrs>();
603 auto type_key = attrs->node_type_key;
604 auto index = attrs->node_index;
605 return MetaRef(type_key, index, ref->span);
606 }
607
608 /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`.
609 * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]`
610 * the second, and so on.
611 */
612 ObjectRef ParseMetaRef() {
613 auto meta_ref_tok = Match(TokenType::kMetaReference);
614 auto meta_ref = MetaRefFromToken(meta_ref_tok);
615 auto it = this->meta_table.find(meta_ref.type_key);
616 if (it != this->meta_table.end()) {
617 auto nodes = (*it).second;
618 if (meta_ref.node_index < nodes.size()) {
619 return nodes[meta_ref.node_index];
620 } else {
621 this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
622 << "the node index `" << meta_ref.node_index
623 << "` is out of bounds for `" << meta_ref.type_key << "`");
624 return ObjectRef();
625 }
626 } else {
627 this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
628 << "no entry in the meta table for `" << meta_ref.type_key << "`");
629 return ObjectRef();
630 }
631 }
  /*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
   * ending with a stop token.
   *
   * The simple form is <start> (<parse()> <separator>)* <stop>.
   *
   * This also takes an optional fourth argument, a callback which is allowed to run when
   * the inner sequence cannot proceed.
   *
   * This is useful for parsing things like attributes which don't match the standard expression
   * parsers but are contained within the stop token.
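   *
   * For example, a parameter list such as `(%x, %y, %z)` is parsed with
   * start = '(' , sep = ',' , stop = ')' and a `parse` callback that parses a single
   * parameter, while the `before_stop` callback is given a chance to consume trailing
   * attributes before the ')' is matched.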
642 */
643 template <typename T>
644 Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
645 std::function<bool()> before_stop = nullptr) {
646 VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
647 << " stop=" << ToString(stop);
648 Match(start);
649
    // This is the empty-argument-list case: if we have a <start> <leftovers> <stop> token stream
    // we must parse the leftovers, then match the stop token.
652 if (before_stop) {
653 auto did_parse = before_stop();
654 if (did_parse) {
655 Match(stop);
656 return {};
657 }
658 }
659
    // This is the case in which we find an empty argument list and no leftovers.
661 if (WhenMatch(stop)) {
662 return Array<T>();
663 } else {
664 VLOG(9) << "Parser::ParseSequence: parse first";
665 auto data = parse();
666 Array<T> elements = {data};
667
668 if (WhenMatch(stop)) {
669 return elements;
670 // parse '( expr ',' * ')'
671 } else if (WhenMatch(sep)) {
672 while (true) {
673 VLOG(9) << "Parser::ParseSequence: parse element";
674 if (WhenMatch(stop)) {
675 break;
676 } else {
          // If the before_stop callback is provided, give it a chance to consume the leftovers.
678 if (before_stop) {
679 auto did_parse = before_stop();
680 if (did_parse) {
681 Match(stop);
682 return elements;
683 }
684 }
685 auto data = parse();
686 WhenMatch(sep);
687 elements.push_back(data);
688 }
689 }
690 return elements;
691 } else {
692 auto next = Peek();
693 this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
694 << "expected a " << Pretty(stop) << " found "
695 << Pretty(next->token_type));
696 return Array<T>(nullptr);
697 }
698 }
699 }
700
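  // Illustrative sketch (not part of the original source) of the overall layout
  // ParseModule expects: a semantic version header, then zero or more definitions,
  // then an optional metadata section. For example:
  //
  //   #[version = "0.0.5"]
  //   type List[A] { Cons(A, List[A]), Nil }
  //   def @main(%x: Tensor[(10, 10), float32]) { %x }
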
701 /*! \brief Parse a full IRModule. */
702 IRModule ParseModule() {
703 // Parse the semver header at the top of the module.
704 this->version = ParseSemVer();
705 // Parse the definitions.
706 auto defs = ParseDefinitions();
707 // Parse the metadata section at the end.
708 auto metadata = ParseMetadata();
709
710 Match(TokenType::kEndOfFile);
711
712 for (auto type_def : defs.types) {
713 module->AddTypeDef(type_def->header, type_def);
714 }
715
716 for (auto func : defs.funcs) {
717 module->Add(func.global, func.function, true);
718 }
719
720 return module;
721 }
722
723 /*! \brief Parse the semantic versioning header. */
724 SemVer ParseSemVer(bool required = true) {
725 if (Peek()->token_type == TokenType::kVersion) {
726 auto version = Match(TokenType::kVersion);
727 // TODO(@jroesch): we currently only support 0.0.5.
728 if (version.ToString() != "\"0.0.5\"") {
729 this->diag_ctx.Emit(Diagnostic::Error(version->span)
730 << "invalid semantic version `" << version.ToString() << "`");
731 }
732 } else if (required) {
733 this->diag_ctx.Emit(Diagnostic::Error(Peek()->span)
734 << "expected text format semantic version, found a "
735 << PrettyPrint(Peek()));
736
737 this->diag_ctx.Emit(Diagnostic::Help(Peek()->span)
738 << "you can annotate it as #[version = \"0.0.5\"]");
739 }
740 return SemVer(0, 0, 5);
741 }
742
743 /*! \brief Parse zero or more Relay definitions. */
744 Definitions ParseDefinitions() {
745 Definitions defs;
746
747 while (true) {
748 auto next = Peek();
749 switch (next->token_type) {
750 case TokenType::kDefn: {
751 Consume(TokenType::kDefn);
752 auto global_tok = Match(TokenType::kGlobal);
753 auto global_name = global_tok.ToString();
754 auto global = AddOrGet(&global_names, global_name);
755 auto func = WithSpan<relay::Function>([&]() { return ParseFunctionDef(); });
756 ICHECK(func->span.defined()) << "spans must be set in parser";
757 defs.funcs.push_back(GlobalFunc(global, func));
758 continue;
759 }
760 case TokenType::kTypeDef: {
761 defs.types.push_back(ParseTypeDef());
762 continue;
763 }
764 case TokenType::kExtern: {
765 Consume(TokenType::kExtern);
766 auto type_def = ParseTypeDef();
767 if (type_def->constructors.size()) {
768 diag_ctx.Emit(Diagnostic::Error(next->span)
769 << "an external type may not have any constructors");
770 }
          defs.types.push_back(type_def);
          // Without this `continue` we would fall through to the default case and stop
          // parsing any definitions that follow the extern type.
          continue;
        }
773 default:
774 return defs;
775 }
776 }
777 }
778
  /*! \brief Parse a Relay type definition. */
780 TypeData ParseTypeDef() {
781 // Match the `type` keyword.
782 Match(TokenType::kTypeDef);
783 // Parse the type's identifier.
784 auto type_tok = Match(TokenType::kIdentifier);
785 auto type_id = type_tok.ToString();
786 auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle);
787
788 Array<TypeVar> generics;
789
790 bool should_pop = false;
791 if (Peek()->token_type == TokenType::kLSquare) {
792 // If we have generics we need to add a type scope.
793 PushTypeScope();
794 should_pop = true;
795 generics = ParseSequence<TypeVar>(
796 TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
797 auto type_var_name = Match(TokenType::kIdentifier).ToString();
798 return BindTypeVar(type_var_name, TypeKind::kType);
799 });
800 }
801
802 Array<tvm::Constructor> ctors;
803 if (Peek()->token_type == TokenType::kLCurly) {
804 // Parse the list of constructors.
805 ctors = ParseSequence<tvm::Constructor>(
806 TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() {
807 // First match the name of the constructor.
808 auto ctor_tok = Match(TokenType::kIdentifier);
809 auto ctor_name = ctor_tok.ToString();
810
811 Constructor ctor;
812 // Match the optional field list.
813 if (Peek()->token_type != TokenType::kOpenParen) {
814 ctor = tvm::Constructor(ctor_name, {}, type_global);
815 } else {
816 auto arg_types =
817 ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
818 TokenType::kCloseParen, [&]() { return ParseType(); });
819 ctor = tvm::Constructor(ctor_name, arg_types, type_global);
820 }
821
822 ICHECK(ctor.defined());
823
824 try {
825 this->ctors.Add(ctor_name, ctor);
826 } catch (const DuplicateKeyError& e) {
827 this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span)
828 << "a constructor with the name "
829 << "`" << ctor_name << "` "
830 << "was previously defined");
831 }
832
833 return ctor;
834 });
835 }
836
837 // Now pop the type scope.
838 if (should_pop) {
839 PopTypeScopes(1);
840 }
841
842 return TypeData(type_global, generics, ctors);
843 }
844
845 std::string HackTokensAsString(int n) {
846 std::stringstream key;
847 n = std::min(static_cast<int>(tokens.size() - pos), n);
848 for (int i = 0; i < n; i++) {
849 key << ToString(tokens.at(pos + i)->token_type);
850 }
851 return key.str();
852 }
853
854 std::vector<Rule> ParseOp() {
855 std::vector<Rule> matched;
856 Peek();
857 for (int i = 4; i > 0; i--) {
858 auto key = HackTokensAsString(i);
859 auto it = this->op_table.this_is_a_hack.find(key);
860 if (it != this->op_table.this_is_a_hack.end()) {
861 pos = pos + i;
862 matched.push_back(it->second);
863 }
864 }
865
866 return matched;
867 }
868
869 /*! \brief Parse a single Relay expression. */
870 Expr ParseExpr() {
871 VLOG(9) << "Parser::ParseExpr";
872 return WithSpan<Expr>([this] {
873 std::vector<Expr> exprs;
874
875 while (true) {
876 VLOG(9) << "Parser::ParseExpr: parsing a single expression";
877 auto next = Peek();
878 switch (next->token_type) {
879 // For graph or let, match first rhs, then invoke ParseBindingExpr
880 // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
881 case TokenType::kLCurly: {
882 // NB: Might need to optimize to remove deep recursion.
883 // Stack should only grow proportionally to the number of
884 // nested scopes.
885 // Parses `{` expression `}`.
886 auto block = WithSpan<Expr>([&]() {
887 return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
888 PushScope();
889 auto expr = ParseExpr();
890 PopScopes(1);
891 return expr;
892 });
893 });
894 exprs.push_back(block);
895 break;
896 }
897 case TokenType::kFreeVar: {
898 Consume(TokenType::kFreeVar);
899 auto var_token = Match(TokenType::kLocal);
900
901 Type type;
902 if (WhenMatch(TokenType::kColon)) {
903 type = ParseType();
904 } else {
905 type = IncompleteType();
906 }
907
908 BindFreeVar(var_token.ToString(), type);
909 break;
910 }
911 // Parses `let ...`;
912 case TokenType::kLet:
913 exprs.push_back(ParseBindingExpr());
914 break;
915 case TokenType::kMatch:
916 case TokenType::kPartialMatch: {
917 bool is_total = next->token_type == TokenType::kMatch;
918 Consume(next->token_type);
919 exprs.push_back(ParseMatch(is_total));
920 break;
921 }
922
923 // %x ...
924 case TokenType::kGraph:
925 if (Lookahead(2)->token_type == TokenType::kEqual) {
926 exprs.push_back(ParseBindingExpr());
927 break;
928 }
929 // intentional fall through here.
930 default: {
931 exprs.push_back(ParseExprBinOp());
932 break;
933 }
934 }
935
936 if (!WhenMatch(TokenType::kSemicolon)) {
937 break;
938 }
939 }
940
941 ICHECK_GE(exprs.size(), 1);
942
943 if (exprs.size() == 1) {
944 // ICHECK(exprs[0].defined() && exprs[0]->span.defined())
945 // << "parser must set expression spans.\n"
946 // << exprs[0];
947 return exprs[0];
948 } else {
949 auto body = exprs.back();
950 exprs.pop_back();
951 while (exprs.size()) {
952 auto value = exprs.back();
953 ICHECK(value->span.defined()) << "parser must set expression spans.";
954 exprs.pop_back();
955 body = relay::Let(Var("", IncompleteType()), value, body, value->span.Merge(body->span));
956 }
957 ICHECK(body->span.defined()) << "parser must set expression spans.";
958 return body;
959 }
960 });
961 }
962
963 /*! \brief Parse a "binding expression"; an expression where
964 * a graph or let variable is bound.
965 *
966 * In order to avoid stack overflow this is implemented in a special
967 * iterative way to keep stack depth constant in a long chain of bindings.
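   *
   * For example, in `let %x = 1; let %y = %x; %y` each binding is handled by one
   * iteration of the loop below, so the call depth stays the same no matter how many
   * bindings appear in sequence.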
968 */
969 Expr ParseBindingExpr() {
970 // We use a loop here so that the stack depth
971 // does not grow linearly with a sequence of
972 // graph or let bindings.
973 //
974 // Assuming we start at call depth k, we will
975 // enter k + c call frames to parse the RHS
976 // of the bindings where `c` is the depth
977 // of recursion needed by RHS.
978 //
    // If the RHS is a call expression then c = 1.
980 //
    // Once we have parsed the RHS we will be
    // back at depth k, and will return to
983 // this loop header to parse another
984 // graph or let binding.
985 //
986 // This ensures for n sequential bindings
987 // the call depth will be the same before
988 // and after parsing the n bindings.
989 VLOG(9) << "Parser::ParseBindingExpr";
990 std::vector<std::tuple<Var, Expr, Span>> bindings;
991 int scopes = 0;
992
993 while (true) {
994 auto next = Peek();
995 if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) {
996 Match(TokenType::kGraph);
997 Match(TokenType::kEqual);
998 auto val = this->ParseExprBinOp();
999 Match(TokenType::kSemicolon);
1000 AddGraphBinding(next, val);
1001 } else if (next->token_type == TokenType::kLet) {
1002 auto span = next->span;
1003 // Parse the 'let'.
1004 Consume(TokenType::kLet);
1005
1006 // Parse the local '%<id>'.
1007 auto local_tok = Match(TokenType::kLocal);
1008 auto string = local_tok.ToString();
1009
1010 // Parse the optional type annotation (':' <type>).
1011 Type type;
1012 if (WhenMatch(TokenType::kColon)) {
1013 type = ParseType();
1014 }
1015
1016 auto var = BindVar(string, type);
1017
1018 // Parse the '=';
1019 Match(TokenType::kEqual);
1020
1021 // Parse the body, and the ';'.
1022 auto val = this->ParseExprBinOp();
1023 Consume(TokenType::kSemicolon);
1024
1025 // Add the bindings to the local data structure.
1026 std::tuple<relay::Var, relay::Expr, Span> tuple(var, val, span);
1027 bindings.push_back(tuple);
1028 scopes++;
1029 PushScope();
1030 } else {
1031 // This is the only case we will increase the stack
1032 // depth.
1033 //
1034 // If we parse a program which is a sequence of N bindings
1035 // followed by a single body expression we will end up with
1036 // a call depth of 3, the first call to ParseExpr, then
1037 // ParseBindingExpr, then finally ParseExpr once more.
1038
1039 auto body = this->ParseExpr();
1040
1041 // Remove the same number of scopes we added.
1042 PopScopes(scopes);
1043
1044 if (bindings.size() == 0) {
1045 return body;
1046 } else {
1047 // We can now build the let binding up backwards.
1048 for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) {
1049 auto span = body->span.Merge(std::get<2>(*binding));
1050 body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span);
1051 }
1052 return body;
1053 }
1054 }
1055 }
1056 }
1057
1058 /*! Parse a function definition without a leading keyword or identifier.
1059 *
1060 * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
1061 */
1062 Function ParseFunctionDef() {
1063 VLOG(9) << "Parser::ParseFunctionDef";
1064 return WithSpan<Function>([&]() {
1065 PushScope();
1066 PushTypeScope();
1067
1068 Array<TypeVar> generics;
1069 if (Peek()->token_type == TokenType::kLSquare) {
1070 generics = ParseSequence<TypeVar>(
1071 TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
1072 auto type_var_name = Match(TokenType::kIdentifier).ToString();
1073 return BindTypeVar(type_var_name, TypeKind::kType);
1074 });
1075 }
1076
1077 Map<String, ObjectRef> raw_attrs;
1078
1079 auto params = ParseSequence<Var>(
1080 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1081 [&]() {
1082 auto token = Match(TokenType::kLocal);
1083 auto string = token.ToString();
1084
1085 // The fake attributes where the virtual device is specified.
1086 VirtualDevice virtual_device;
1087 if (WhenMatch(TokenType::kLCurly)) {
1088 Map<String, ObjectRef> fake_attrs = ParseAttrs();
1089 VLOG(9) << "Fake attributes for function parameter: " << fake_attrs;
1090 Match(TokenType::kRCurly);
1091 if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) {
            ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>())
                << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
                << fake_attrs[kVirtualDevice]->GetTypeKey();
1095 virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]);
1096 }
1097 }
1098
1099 Type type;
1100 if (WhenMatch(TokenType::kColon)) {
1101 type = ParseType();
1102 }
1103 return BindVar(string, type, virtual_device);
1104 },
1105 [&] {
1106 auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
1107 auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
1108
1109 if (is_ident && next_is_equal) {
1110 raw_attrs = ParseAttrs();
1111 return true;
1112 }
1113
1114 return false;
1115 });
1116
1117 Type ret_type;
1118 if (WhenMatch(TokenType::kMinus)) {
1119 Match(TokenType::kRAngle);
1120 ret_type = ParseType();
1121 }
1122
1123 auto body = Block<Expr>([&]() { return ParseExpr(); });
1124
1125 PopTypeScopes(1);
1126 PopScopes(1);
1127
1128 // TODO(@jroesch): attributes should never be null, they should always be empty.
1129 if (raw_attrs.size()) {
1130 // Promote kVirtualDevice to first-class
1131 if (raw_attrs.count(kVirtualDevice)) {
1132 ObjectRef vid = raw_attrs.at(kVirtualDevice);
1133 ICHECK(vid.as<VirtualDeviceNode>())
1134 << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
1135 << vid->GetTypeKey();
1136
1137 DictAttrs attrs;
1138 // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the
1139 // attributes
1140 if (raw_attrs.size() > 1) {
1141 raw_attrs.erase(kVirtualDevice);
1142 attrs = DictAttrs(raw_attrs);
1143 }
1144 Function func = relay::Function(params, body, ret_type, generics, attrs);
1145 func->virtual_device_ = vid;
1146 return func;
1147 } else {
1148 return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1149 }
1150 } else {
1151 return relay::Function(params, body, ret_type, generics, tvm::DictAttrs());
1152 }
1153 });
1154 }
1155
1156 /*! \brief Parse an if-expression. */
1157 Expr ParseIf() {
1158 return WithSpan<Expr>([&]() {
1159 VLOG(9) << "Parser::ParseIf";
1160 Consume(TokenType::kIf);
1161
1162 auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); });
1163
1164 auto true_branch = Block<Expr>([&] {
1165 this->PushScope();
1166 auto expr = ParseExpr();
1167 this->PopScopes(1);
1168 return expr;
1169 });
1170
1171 Match(TokenType::kElse);
1172
1173 auto false_branch = Block<Expr>([&] {
1174 this->PushScope();
1175 auto expr = ParseExpr();
1176 this->PopScopes(1);
1177 return expr;
1178 });
1179
1180 return relay::If(guard, true_branch, false_branch);
1181 });
1182 }
1183
1184 /* This factors parsing a list of patterns for both tuples, and constructors. */
1185 Array<Pattern> ParsePatternList() {
1186 return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1187 [&] { return ParsePattern(); });
1188 }
1189
1190 /*! \brief Parses a pattern for a match expression.
1191 *
1192 * A pattern is either a wildcard `_`, a local `%name`,
   * a constructor `C(p1, ..., pn)` or a tuple `(p1, ..., pn)`.
1194 *
1195 * This function recursively parses a pattern.
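   *
   * For example, `Cons(%hd, %tl)` parses to a constructor pattern whose two fields are
   * the pattern variables `%hd` and `%tl`, while `(_, %x)` parses to a tuple pattern of
   * a wildcard and a variable.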
1196 */
1197 Pattern ParsePattern() {
1198 VLOG(9) << "Parser::ParsePattern";
1199 auto next = Peek();
1200 switch (next->token_type) {
1201 case TokenType::kUnderscore: {
1202 Match(TokenType::kUnderscore);
1203 return PatternWildcard();
1204 }
1205 case TokenType::kLocal: {
1206 auto id = Match(TokenType::kLocal);
1207 Type type_annotation;
1208 if (WhenMatch(TokenType::kColon)) {
1209 type_annotation = ParseType();
1210 }
1211 auto var = BindVar(id.ToString(), type_annotation);
1212 return PatternVar(var);
1213 }
1214 case TokenType::kIdentifier: {
1215 auto id = Match(TokenType::kIdentifier);
1216 auto ctor = ctors.Get(id.ToString());
1217 if (!ctor) {
1218 diag_ctx.EmitFatal(
1219 // TODO(@jroesch): split into error and help
1220 // deal with multiple rendering
1221 Diagnostic::Error(id->span)
1222 << "undefined constructor name `" << id.ToString()
                << "`, perhaps you intended to write a "
                << "pattern variable; consider changing this to `%" << id.ToString() << "`");
1225 }
1226 if (Peek()->token_type == TokenType::kOpenParen) {
1227 auto fields = ParsePatternList();
1228 return PatternConstructor(ctor.value(), fields);
1229 } else {
1230 return PatternConstructor(ctor.value(), {});
1231 }
1232 }
1233 default:
1234 return PatternTuple(ParsePatternList());
1235 }
1236 }
1237
1238 Clause ParseMatchArm() {
1239 PushScope();
1240 auto pattern = ParsePattern();
1241 Match(TokenType::kEqual);
1242 Consume(TokenType::kRAngle);
1243 auto expr = ParseExpr();
1244 PopScopes(1);
1245 return Clause(pattern, expr);
1246 }
1247
1248 Expr ParseMatch(bool is_total) {
1249 return WithSpan<Expr>([&]() {
1250 Expr scrutinee = ParseAtomicExpr();
1251
1252 Array<Clause> clauses =
1253 ParseSequence<Clause>(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly,
1254 [&] { return ParseMatchArm(); });
1255
1256 return relay::Match(scrutinee, clauses, is_total);
1257 });
1258 }
1259
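
  // Illustrative sketch (not part of the original source): ParseExprBinOp runs a
  // precedence-climbing loop over the expression and operator stacks below. Assuming
  // the operator table gives `*` a higher precedence than `+`, an input such as
  //   %a + %b * %c
  // reduces the multiplication first and yields the equivalent of
  //   add(%a, multiply(%b, %c)).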
1260 Expr ParseExprBinOp() {
1261 VLOG(9) << "Parser::ParseExprBinOp";
1262 return WithSpan<Expr>([this] {
1263 // We must parse at least one expression, the default
1264 // case is that there is no operator and we will fall
1265 // through.
1266 std::vector<Expr> exprs;
1267 Expr expr = WithSpan<Expr>([this] { return ParseCallExpr(); });
1268
1269 exprs.push_back(expr);
1270
1271 // Now we parse an optional op.
1272 std::vector<Rule> ops;
1273
1274 // We will now parse 0 or more operator occurrences.
1275 while (true) {
1276 auto opt_op = ParseOp();
1277
        // If we didn't parse one, we are done.
1279 if (opt_op.size() == 0) {
1280 break;
1281 }
1282
        // Read the operator we parsed.
1284 auto op = opt_op[0];
1285
1286 Expr right = WithSpan<Expr>([this] { return ParseCallExpr(); });
1287 ICHECK(right->span.defined());
1288
1289 // If the operator stack is empty
1290 // we parse an operator and expression
1291 // and push them to stacks, then
1292 // continue.
1293 if (ops.size() == 0) {
1294 ops.push_back(op);
1295 exprs.push_back(right);
1296 continue;
1297 }
1298
1299 if (op.precedence > ops.back().precedence ||
1300 (op.precedence == ops.back().precedence && op.left_assoc == false)) {
1301 ops.push_back(op);
1302 exprs.push_back(right);
1303 continue;
1304 }
1305
1306 while (ops.size() && (op.precedence < ops.back().precedence ||
1307 (op.precedence == ops.back().precedence && op.left_assoc == true))) {
1308 Rule new_op = ops.back();
1309 ops.pop_back();
1310 Expr right = exprs.back();
1311 exprs.pop_back();
1312 Expr left = exprs.back();
1313 exprs.pop_back();
1314 ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
1315 exprs.push_back(
1316 relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
1317 }
1318
1319 exprs.push_back(right);
1320 ops.push_back(op);
1321 }
1322
1323 while (ops.size()) {
1324 Rule new_op = ops.back();
1325 ops.pop_back();
1326 Expr right = exprs.back();
1327 exprs.pop_back();
1328 Expr left = exprs.back();
1329 exprs.pop_back();
1330 ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
1331 exprs.push_back(
1332 relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
1333 }
1334
1335 ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack.";
1336
1337 ICHECK_EQ(exprs.size(), 1)
1338 << "Only a single expression should be left on the expression stack.";
1339
1340 return exprs[0];
1341 });
1342 }
1343
1344 ObjectRef ParseAttributeValue() {
1345 VLOG(9) << "Parser::ParseAttributeValue";
1346 auto next = Peek();
1347 switch (next->token_type) {
1348 case TokenType::kFloat:
1349 case TokenType::kInteger:
1350 case TokenType::kBoolean:
1351 case TokenType::kStringLiteral:
1352 return Match(next->token_type)->data;
1353 case TokenType::kMetaReference:
1354 return ParseMetaRef();
1355 case TokenType::kLSquare: {
1356 return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
1357 [&]() { return ParseAttributeValue(); });
1358 }
1359 case TokenType::kOpenParen: {
        // TODO(@jroesch): need to figure out bracket vs. sequence
1361 // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma,
1362 // TokenType::kCloseParen,
1363 // [&]() { return ParseAttributeValue(); });
1364 return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen,
1365 [&]() { return ParseAttributeValue(); });
1366 }
1367 // TODO(@jroesch): not sure about this being the right way to handle nulls.
1368 case TokenType::kIdentifier: {
1369 if (auto text = next->data.as<tvm::StringObj>()) {
1370 std::string id = GetRef<String>(text);
1371 if (id == "nullptr") {
1372 Match(TokenType::kIdentifier);
1373 return ObjectRef();
1374 }
1375 if (id == "None") {
1376 Match(TokenType::kIdentifier);
1377 return Optional<ObjectRef>();
1378 }
1379 }
1380 }
1381 default:
1382 return ParseAtomicExpr();
1383 }
1384 }
1385
1386 Map<String, ObjectRef> ParseAttrs() {
1387 VLOG(9) << "Parser::ParseAttrs";
1388 Map<String, ObjectRef> kwargs;
1389 while (Peek()->token_type == TokenType::kIdentifier) {
1390 auto key = GetHierarchicalName(ParseHierarchicalName().data);
1391 Match(TokenType::kEqual);
      // TODO(@jroesch): syntactically, what do we allow to appear on an attribute's right-hand side?
1393 auto value = ParseAttributeValue();
      // TODO(@jroesch): we need a robust way to handle this; writing dtypes as strings in the
      // text format is bad.
1396 kwargs.Set(key, value);
1397 WhenMatch(TokenType::kComma);
1398 }
1399 VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs;
1400 return kwargs;
1401 }
1402
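  // Illustrative sketch (not part of the original source): in a call such as
  //   nn.conv2d(%x, %w, padding=[1, 1, 1, 1])
  // the trailing `key=value` pairs are collected by ParseAttrs above, and ParseCallArgs
  // below turns them into the operator's attribute object via the reflection table,
  // using the operator's `attrs_type_key`.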
1403 Expr ParseCallArgs(Expr op) {
1404 ICHECK(op.defined()) << "the operator must be defined";
1405
1406 VLOG(9) << "Parser::ParseCallArgs";
1407 Attrs attrs;
1408 std::string op_key;
1409 bool is_op = false;
1410
1411 if (auto op_node = op.as<OpNode>()) {
1412 is_op = true;
1413 op_key = op_node->attrs_type_key;
1414 }
1415
1416 if (Peek()->token_type == TokenType::kOpenParen) {
1417 Array<Expr> args = ParseSequence<Expr>(
1418 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
1419 [&] { return ParseExpr(); },
1420 [&] {
1421 auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
1422 auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
1423 auto is_pretty_attrs = is_ident && next_is_equal;
1424 auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
1425 // TODO(@jroesch): might not handle trailing comma
1426 auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
1427 auto is_meta_attrs = is_meta_next && last_meta;
1428
1429 if (is_pretty_attrs || is_meta_attrs) {
1430 if (is_meta_attrs) {
1431 auto meta_ref = ParseMetaRef();
1432 if (meta_ref.as<BaseAttrsNode>()) {
1433 attrs = Downcast<Attrs>(meta_ref);
1434 } else {
1435 // Not awesome parsing code here.
1436 this->pos--;
1437 return false;
1438 }
1439 } else {
1440 auto raw_attrs = ParseAttrs();
1441 if (is_op && op_key.size()) {
1442 auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
1443 ICHECK(attr_obj.defined());
1444 attrs = Downcast<Attrs>(attr_obj);
1445 } else if (raw_attrs.count("attrs_type_key")) {
1446 String attr_key = Downcast<String>(raw_attrs["attrs_type_key"]);
1447 if (attr_key.size()) {
1448 raw_attrs.erase("attrs_type_key");
1449 auto attr_obj =
1450 tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs);
1451 ICHECK(attr_obj.defined());
1452 attrs = Downcast<Attrs>(attr_obj);
1453 }
1454 } else {
1455 this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
1456 << "unable to determine the 'attrs_type_key' with which "
1457 "to represent the call attributes for this operator");
1458 }
1459 }
1460 return true;
1461 }
1462 return false;
1463 });
1464
1465 if (!attrs.defined()) {
1466 if (is_op && op_key.size()) {
1467 auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
1468 ICHECK(attr_obj.defined());
1469 attrs = Downcast<Attrs>(attr_obj);
1470 }
1471 }
1472
1473 // TODO(@jroesch): in a secondary pass adjust spans.
1474 return Expr(Call(op, args, attrs, {}));
1475 } else {
1476 return Expr();
1477 }
1478
1479 return Expr();
1480 }
1481
1482 Expr ParseCallExpr() {
1483 VLOG(9) << "Parser::ParseCallExpr";
1484 return WithSpan<Expr>([this] {
1485 Expr expr = ParseAtomicExpr();
1486 // Parse as many call args as possible, building up expression
1487 //
1488 // NB(@jroesch): this seems like a hack but in order to parse curried functions
1489 // and avoid complex grammar we will parse multiple call lists in a row.
1490 while (Peek()->token_type == TokenType::kOpenParen) {
1491 auto new_expr = ParseCallArgs(expr);
1492
1493 if (new_expr.defined()) {
1494 expr = new_expr;
1495 } else {
1496 break;
1497 }
1498 }
1499
1500 // We need a zero-arity case for constructors.
1501 if (auto ctor_node = expr.as<ConstructorNode>()) {
1502 if (ctor_node->inputs.size() == 0) {
1503 return Expr(Call(expr, {}));
1504 }
1505 }
1506
1507 return expr;
1508 });
1509 }
1510
1511 Expr GetOp(const std::string& op_name, const Span& span) {
1512 VLOG(9) << "op_name=" << op_name << " span=" << span;
1513 try {
1514 return Op::Get(op_name);
1515 } catch (const Error& e) {
1516 // we can relax this, but probably need to relax checks or return non-null here.
1517 this->diag_ctx.EmitFatal(Diagnostic::Error(span)
1518 << "operator `" << op_name
1519 << "` not found, perhaps you forgot to register it?");
1520 return Expr();
1521 }
1522 }
1523
1524 Expr ParseAtomicExpr() {
1525 VLOG(9) << "Parser::ParseAtomicExpr";
1526 Expr expr = WithSpan<Expr>([this] {
1527 auto next = Peek();
1528 switch (next->token_type) {
1529 case TokenType::kInteger:
1530 case TokenType::kFloat: {
1531 Consume(next->token_type);
1532 auto number = NumberToNDArray(next);
1533 Expr e = Constant(number, next->span);
1534 ICHECK(e->span.defined()) << "constant spans must be defined";
1535 return e;
1536 }
1537 case TokenType::kBoolean: {
1538 Consume(TokenType::kBoolean);
1539 int64_t value = Downcast<tvm::Integer>(next->data).IntValue();
1540 Expr e = Constant(support::BoolToNDArray(value), next->span);
1541 ICHECK(e->span.defined()) << "constant spans must be defined";
1542 return e;
1543 }
1544 // Parse a local of the form `%x`.
1545 case TokenType::kLocal: {
1546 Consume(TokenType::kLocal);
1547 return Expr(LookupLocal(next));
1548 }
        // Parse a global of the form `@x`.
1550 case TokenType::kGlobal: {
1551 auto global_name = next.ToString();
1552 Consume(TokenType::kGlobal);
1553 auto global = AddOrGet(&global_names, global_name);
1554 return Expr(global);
1555 }
        // Parse an identifier of the form `x`: either a constructor name or a
        // (possibly hierarchical) operator name such as `nn.conv2d`.
1558 case TokenType::kIdentifier: {
1559 auto ctor = ctors.Get(next.ToString());
1560 if (ctor) {
1561 Consume(TokenType::kIdentifier);
1562 return Expr(ctor.value());
1563 } else {
1564 auto spanned_idents = ParseHierarchicalName();
1565 auto idents = spanned_idents.data;
1566 auto span = spanned_idents.span;
1567 return GetOp(GetHierarchicalName(idents), span);
1568 }
1569 }
1570 case TokenType::kGraph: {
1571 Consume(TokenType::kGraph);
1572 return LookupGraphBinding(next);
1573 }
1574 case TokenType::kMetaReference: {
1575 return Downcast<Expr>(ParseMetaRef());
1576 }
1577 case TokenType::kFn: {
1578 Consume(TokenType::kFn);
1579 Expr e = ParseFunctionDef();
1580 ICHECK(e->span.defined()) << "function spans must be defined.\n" << e;
1581 return e;
1582 }
1583 case TokenType::kIf: {
1584 Expr e = ParseIf();
1585 return e;
1586 }
1587 case TokenType::kRef: {
1588 Consume(TokenType::kRef);
1589 Match(TokenType::kOpenParen);
1590 auto ref_value = ParseExpr();
1591 Match(TokenType::kCloseParen);
1592 return static_cast<Expr>(RefCreate(ref_value));
1593 }
1594 case TokenType::kRefRead: {
1595 return WithSpan<Expr>([&]() {
1596 Consume(TokenType::kRefRead);
1597 Match(TokenType::kOpenParen);
1598 auto ref = ParseExpr();
1599 Match(TokenType::kCloseParen);
1600 return static_cast<Expr>(RefRead(ref));
1601 });
1602 }
1603 case TokenType::kRefWrite: {
1604 return WithSpan<Expr>([&]() {
1605 Consume(TokenType::kRefWrite);
1606 Match(TokenType::kOpenParen);
1607 auto ref = ParseExpr();
1608 Match(TokenType::kComma);
1609 auto value = ParseExpr();
1610 Match(TokenType::kCloseParen);
1611 return static_cast<Expr>(RefWrite(ref, value));
1612 });
1613 }
1614 case TokenType::kOpenParen: {
1615 Span sp = next->span;
1616 Consume(TokenType::kOpenParen);
1617 // parse '(' ')'
1618 if (WhenMatch(TokenType::kCloseParen)) {
1619 return Expr(Tuple(Array<Expr>()));
1620 } else {
1621 Expr subexpr = ParseExpr();
1622 // parse '(' expr ')'
1623 if (WhenMatch(TokenType::kCloseParen)) {
1624 return subexpr;
1625 // parse '( expr ',' * ')'
1626 } else if (WhenMatch(TokenType::kComma)) {
1627 Array<Expr> exprs = {subexpr};
1628 while (true) {
1629 if (WhenMatch(TokenType::kCloseParen)) {
1630 break;
1631 } else {
1632 auto element = ParseExpr();
1633 auto comma = Peek();
1634 if (WhenMatch(TokenType::kComma)) {
1635 sp = sp.Merge(element->span.Merge(comma->span));
1636 } else {
1637 sp = sp.Merge(element->span);
1638 }
1639 exprs.push_back(element);
1640 }
1641 }
1642 Expr tuple = Tuple(exprs, sp);
1643 ICHECK(tuple->span.defined()) << "tuple span should be defined";
1644 return tuple;
1645 }
1646 }
1647 }
1648 default: {
1649 this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
1650 << "expected an expression found " << Pretty(next->token_type));
1651 return Expr();
1652 }
1653 }
1654 });
1655
1656 if (WhenMatch(TokenType::kPeriod)) {
1657 auto token = Match(TokenType::kInteger);
1658 auto index = token.ToNumber();
1659 auto span = token->span.Merge(expr->span);
1660 VLOG(9) << "Parser::ParseAtomicExpr: tuple get item";
1661 return relay::TupleGetItem(expr, index, span);
1662 } else {
1663 return expr;
1664 }
1665 }
1666
1667 /*! \brief Parse a hierarchical name.
1668 *
1669 * The tokenizer produces a token stream of <id1> . <id2>
1670 * and so on for names of the form `nn.conv2d`.
1671 * Currently we only use string names everywhere instead
1672 * of a notion of a hierarchical name.
1673 *
   * The utility below reassembles such a token stream into a
   * single list of identifiers; GetHierarchicalName then re-inserts
   * the periods needed to look up registered names.
1677 */
1678 Spanned<Array<String>> ParseHierarchicalName() {
1679 Array<String> idents;
1680 Span span;
1681 while (Peek()->token_type == TokenType::kIdentifier) {
1682 auto token = Peek();
1683
1684 if (span.defined()) {
1685 span = span.Merge(token->span);
1686 } else {
1687 span = token->span;
1688 }
1689
1690 auto name = token.ToString();
1691 idents.push_back(name);
1692 Consume(TokenType::kIdentifier);
1693
1694 // Keep parsing while we see a trailing period.
1695 if (Peek()->token_type == TokenType::kPeriod) {
1696 Consume(TokenType::kPeriod);
1697 continue;
1698 } else {
1699 // No more periods means we are done!
1700 break;
1701 }
1702 }
1703
1704 return Spanned<Array<String>>(idents, span);
1705 }
1706
1707 std::string GetHierarchicalName(Array<String> idents) {
1708 ICHECK_NE(idents.size(), 0);
1709 std::stringstream hierarchical_name;
1710 int i = 0;
1711 int periods = idents.size() - 1;
1712 for (auto ident : idents) {
1713 hierarchical_name << ident;
1714 if (i < periods) {
1715 hierarchical_name << ".";
1716 i++;
1717 }
1718 }
1719 return hierarchical_name.str();
1720 }
1721
1722 /*! \brief Parse a shape. */
1723 Array<tvm::PrimExpr> ParseShape() {
1724 auto dims = ParseSequence<tvm::PrimExpr>(
1725 TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
1726 tvm::PrimExpr dim;
1727 if (Peek()->token_type == TokenType::kMetaReference) {
1728 dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
1729 } else if (WhenMatch(TokenType::kQuestion)) {
1730 dim = tvm::tir::Any();
1731 } else {
1732 dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
1733 }
1734
1735 return dim;
1736 });
1737 return dims;
1738 }
1739
1740 /*! \brief Parse a function type. */
1741 Type ParseFunctionType() {
1742 auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma,
1743 TokenType::kCloseParen, [&]() { return ParseType(); });
1744
1745 Match(TokenType::kMinus);
1746 Match(TokenType::kRAngle);
1747 auto ret_type = ParseType();
1748
1749 return relay::FuncType(ty_params, ret_type, {}, {});
1750 }
1751
  // Parses a user defined ADT or type variable.
  Type ParseNonPrimitiveType(const Token& tok) {
    return WithSpan<Type>([&]() {
      auto name = tok.ToString();
      Type head_type = LookupTypeVar(tok);

      if (!head_type.defined()) {
        // head_type = type_names.Get(name);
        head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle);
      }

      if (!head_type.defined()) {
        diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
                           << "the type constructor `" << name << "` is undefined");
      }

      Array<Type> arg_types;
      if (Peek()->token_type == TokenType::kLSquare) {
        arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
                                        [&]() { return ParseType(); });
      }

      if (arg_types.size()) {
        return static_cast<Type>(TypeCall(head_type, arg_types));
      } else {
        if (head_type.as<GlobalTypeVarNode>()) {
          return static_cast<Type>(TypeCall(head_type, {}));
        } else {
          return static_cast<Type>(head_type);
        }
      }
    });
  }

  /*! \brief Parses a TVM type.
   *
   * This matches either a `Tensor[shape, dtype]`, a function type, a user defined ADT,
   * a tuple type, a scalar type, or the incomplete type `_`.
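   *
   * For example, all of the following are accepted here:
   *   Tensor[(1, 3, 224, 224), float32]   // tensor type with explicit shape
   *   (int32, bool)                       // tuple type
   *   fn (float32) -> float32             // function type
   *   float32                             // scalar (rank-0 tensor) type
   *   _                                   // incomplete type, to be inferred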
   */
  Type ParseType() {
    return WithSpan<Type>([&]() -> Type {
      auto tok = Peek();

      if (tok->token_type == TokenType::kOpenParen) {
        auto tys =
            ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma,
                                       TokenType::kCloseParen, [&]() { return ParseType(); });
        return relay::TupleType(tys);
      } else if (WhenMatch(TokenType::kFn)) {
        return ParseFunctionType();
      } else if (WhenMatch(TokenType::kIdentifier)) {
        auto id = tok.ToString();
        if (id == "Tensor") {
          Match(TokenType::kLSquare);
          auto shape = ParseShape();
          Match(TokenType::kComma);
          auto dtype_tok = Match(TokenType::kIdentifier);
          auto dtype = DataType(String2DLDataType(dtype_tok.ToString()));
          Match(TokenType::kRSquare);
          return TensorType(shape, dtype);
        } else {
          auto ty = tok.ToString();
          if (ty.rfind("int", 0) == 0 || ty.rfind("float", 0) == 0 || ty.rfind("uint", 0) == 0 ||
              ty.rfind("bool", 0) == 0) {
            // Need to do better error handling here.
            auto dtype = DataType(String2DLDataType(tok.ToString()));
            return TensorType({}, dtype);
          } else {
            return ParseNonPrimitiveType(tok);
          }
        }
      } else if (WhenMatch(TokenType::kUnderscore)) {
        return IncompleteType();
      } else {
        this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span)
                                 << "failed to parse type, found " << tok);
        return Type();
      }
    });
  }

  template <typename R>
  R ConsumeWhitespace(std::function<R()> func) {
    auto old = this->ignore_whitespace;
    this->ignore_whitespace = true;
    while (tokens[pos]->token_type == TokenType::kWhitespace) {
      pos++;
    }
    auto res = func();
    this->ignore_whitespace = old;
    return res;
  }

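  /*! \brief Parse the `#[metadata]` section, if the next token is one, into a map
   * from type key to a sequence of objects; otherwise return an empty map.
   */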
  Map<String, Array<ObjectRef>> ParseMetadata() {
    if (Peek()->token_type == TokenType::kMetadata) {
      return Match(TokenType::kMetadata).ToMetadata();
    } else {
      return Map<String, Array<ObjectRef>>();
    }
  }

  /*! \brief A helper for debugging the parser; displays the next N tokens in the token stream. */
  void DisplayNextN(int n) {
    std::cout << "remaining tokens: " << std::endl;
    auto bound = std::min(pos + n, static_cast<int>(tokens.size()));
    for (int i = 0; i < bound - pos; i++) {
      std::cout << tokens[pos + i] << std::endl;
    }
  }

  // A function for debugging the operator parser.
  void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) {
    std::cout << "Expr Stack: ";
    for (auto expr : exprs) {
      std::cout << expr << ", ";
    }

    std::cout << std::endl;
    std::cout << "Op Stack: ";
    for (auto rule : rules) {
      std::cout << rule.op << ", ";
    }

    std::cout << std::endl;
  }
};

Parser InitParser(const std::string& file_name, const std::string& file_content,
                  const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
  VLOG(9) << "InitParser: file_name: " << file_name
          << " file_content_size: " << file_content.size();
  SourceName src_name = SourceName::Get(file_name);
  Source source(src_name, file_content);

  IRModule module;
  if (!init_module) {
    SourceMap source_map;
    module = IRModule({}, {}, {}, source_map);
  } else {
    module = init_module.value();
  }

  module->source_map.Add(source);

  auto diag_ctx = DiagnosticContext::Default(module);
  auto tokens_and_table = Tokenize(diag_ctx, source);

  auto tokens = tokens_and_table.first;
  MetaTable meta_data_table = tokens_and_table.second.ToMetadata();

  // Merge any entries in init_meta_table into anything captured in the #[metadata] section
  // of the file_content. Metadata references within file_content must use indices which account
  // for this ordering.
  for (const auto& pair : init_meta_table) {
    Array<ObjectRef> items;
    if (meta_data_table.count(pair.first)) {
      items = meta_data_table[pair.first];
    }
    for (const auto& obj : pair.second) {
      items.push_back(obj);
    }
    meta_data_table.Set(pair.first, items);
  }

  return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
}

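/*!
 * \brief Parse file_content into an IRModule and run type inference over it.
 *
 * A minimal illustrative use (note the `#[version = "0.0.5"]` header expected at the
 * start of a module):
 *
 * \code
 *   IRModule mod = ParseModule(
 *       "example.relay", "#[version = \"0.0.5\"]\ndef @main(%x: int32) -> int32 { %x }");
 * \endcode
 */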
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
                     const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
  VLOG_CONTEXT << "ParseModule";
  VLOG(9) << "parsing and type-checking " << file_name;
  auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
  auto mod = parser.ParseModule();
  ICHECK(mod.defined()) << "The parser must return a non-null module.";
  // NB(@jroesch): it is very important that we render any diagnostics before proceeding.
  // If errors occurred but still allowed the parser to continue, they must be reported here.
  parser.diag_ctx.Render();
  auto infer_type = tvm::relay::transform::InferType();
  ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
  return infer_type(mod);
}

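/*!
 * \brief Parse file_content as a single Relay expression.
 *
 * The `#[version]` header is optional here because ParseSemVer is invoked with
 * `false` below. An illustrative use:
 *
 * \code
 *   Expr e = ParseExpr("string", "let %x = 1; %x");
 * \endcode
 */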
Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
  VLOG(9) << "ParseExpr";
  auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
  parser.ParseSemVer(false);
  parser.PushScope();
  auto expr = parser.ParseExpr();
  parser.Match(TokenType::kEndOfFile);
  // NB(@jroesch): it is very important that we render any diagnostics before proceeding.
  // If errors occurred but still allowed the parser to continue, they must be reported here.
  parser.diag_ctx.Render();
  return expr;
}

/*!
 * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
 * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
 * modules constructed programmatically rather than textually.
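 *
 * Illustrative use, applied like any other module pass:
 *
 * \code
 *   IRModule with_spans = AnnotateSpans()(mod);
 * \endcode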
 */
Pass AnnotateSpans() {
  auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
    String text = AsText(mod, /*show_meta_data=*/true);
    VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
    return ParseModule("GeneratedSource", text);
  };
  return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay.parser.ParseModuleInContext")
    .set_body_typed([](const std::string& file_name, const std::string& file_content,
                       const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
      return ParseModule(file_name, file_content, init_module, init_meta_table);
    });

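// ParseModule accepts optional trailing arguments, so it is registered with a generic
// PackedFunc body that dispatches on however many arguments the frontend supplies.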
TVM_REGISTER_GLOBAL("relay.parser.ParseModule").set_body([](TVMArgs args, TVMRetValue* ret) {
  ICHECK(args.size() >= 2 && args.size() <= 4) << "Expected 2-4 arguments, but got " << args.size();
  if (args.size() == 2) {
    *ret = ParseModule(args[0], args[1]);
  } else if (args.size() == 3) {
    *ret = ParseModule(args[0], args[1], args[2]);
  } else {
    *ret = ParseModule(args[0], args[1], args[2], args[3]);
  }
});

TVM_REGISTER_GLOBAL("relay.parser.ParseExpr")
    .set_body_typed([](tvm::String file_name, tvm::String file_content) {
      return ParseExpr(file_name, file_content);
    });

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

}  // namespace relay
}  // namespace tvm