1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file parser.cc |
22 | * \brief A parser for TVM IR. |
23 | */ |
#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/parser.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/builtin_fp16.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/virtual_device.h>

#include <fstream>
#include <utility>

#include "../../support/scalars.h"
#include "./meta_ref.h"
#include "./op_table.h"
#include "./span_check.h"
#include "./tokenizer.h"
44 | |
45 | namespace tvm { |
46 | namespace relay { |
47 | |
/*! \brief The meta table maps from type key to a sequence of objects. */
using MetaTable = Map<String, Array<ObjectRef>>;

// Using-declarations: make these two tvm::transform names usable unqualified in this namespace.
using tvm::transform::CreateModulePass;
using tvm::transform::PassContext;
53 | |
54 | /*! \brief A helper for passing around spans with data structures with |
55 | * no span field. |
56 | */ |
57 | template <typename T> |
58 | struct Spanned { |
59 | T data; |
60 | Span span; |
61 | |
62 | Spanned() = default; |
63 | Spanned(const Spanned<T>& other) = default; |
64 | Spanned(T data, Span span) : data(data), span(span) {} |
65 | }; |
66 | |
67 | /*! \brief A wrapper structure for capturing the result of parsing |
68 | * a global definition *before* we add it to the IRModule. |
69 | * |
70 | * This enables the parser to parse everything in one pass before |
71 | * constructing the IRModule. |
72 | */ |
73 | struct GlobalFunc { |
74 | GlobalVar global; |
75 | Function function; |
76 | GlobalFunc() : global(), function() {} |
77 | GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {} |
78 | GlobalFunc(const GlobalFunc& gfunc) { |
79 | this->global = gfunc.global; |
80 | this->function = gfunc.function; |
81 | } |
82 | }; |
83 | |
/*! \brief A wrapper structure for capturing all top-level definitions
 * when parsing a module.
 *
 * Both members are plain std::vector containers, so entries stay in the
 * order in which they were appended.
 */
struct Definitions {
  /*! \brief The set of global functions. */
  std::vector<GlobalFunc> funcs;
  /*! \brief The set of type definitions. */
  std::vector<TypeData> types;
  // TODO(@jroesch): contain meta-table below
};
94 | |
/*!
 * \brief A structure representing the semantic versioning information
 *  for a Relay program.
 *
 * The three components are plain integers that default to zero. No copy
 * operations are declared (Rule of Zero), so the compiler-generated copy
 * constructor and copy assignment are both available and the type is
 * trivially copyable.
 */
class SemVer {
 public:
  /*! \brief The major version component. */
  int major_version{0};
  /*! \brief The minor version component. */
  int minor_version{0};
  /*! \brief The patch version component. */
  int patch_version{0};

  /*! \brief Construct version 0.0.0. */
  SemVer() = default;

  /*!
   * \brief Construct a version from its three components.
   * \param major_version The major component.
   * \param minor_version The minor component.
   * \param patch_version The patch component.
   */
  SemVer(int major_version, int minor_version, int patch_version)
      : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {}
};
112 | |
/*! \brief A simple wrapper around a mapping from raw string names
 * to a TVM variable, type variable or other binder type.
 *
 * \tparam T The type of value each name is bound to.
 */
template <typename T>
struct Scope {
  /*! \brief The internal map. */
  std::unordered_map<std::string, T> name_map;
};
121 | |
122 | /*! \brief A stack of scopes. |
123 | * |
124 | * In order to properly handle scoping we must maintain a stack of scopes. |
125 | * |
126 | * A stack allows users to write programs which contain repeated variable |
127 | * names and to properly handle both nested scopes and removal of variables |
128 | * when they go out of scope. |
129 | * |
130 | * This is the classic approach to lexical scoping. |
131 | */ |
132 | template <typename T> |
133 | class ScopeStack { |
134 | private: |
135 | std::vector<Scope<T>> scope_stack; |
136 | std::unordered_map<std::string, T> free_vars; |
137 | |
138 | public: |
139 | /*! \brief Adds a variable binding to the current scope. */ |
140 | void Add(const std::string& name, const T& value) { |
141 | if (!this->scope_stack.size()) { |
142 | LOG(FATAL) << "internal issue" ; |
143 | } |
144 | this->scope_stack.back().name_map.insert({name, value}); |
145 | } |
146 | |
147 | void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } |
148 | |
149 | /*! \brief Looks up a variable name in the scope stack returning the matching variable |
150 | * in most recent scope. */ |
151 | T Lookup(const std::string& name) { |
152 | for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) { |
153 | auto it = scope->name_map.find(name); |
154 | if (it != scope->name_map.end()) { |
155 | return it->second; |
156 | } |
157 | } |
158 | |
159 | // Check if we bound a free variable declaration. |
160 | auto it = free_vars.find(name); |
161 | if (it != free_vars.end()) { |
162 | return it->second; |
163 | } |
164 | |
165 | return T(); |
166 | } |
167 | |
168 | /*! \brief Adds a fresh scope. */ |
169 | void PushStack() { this->scope_stack.push_back(Scope<T>()); } |
170 | |
171 | /*! \brief Removes the most recent scope. */ |
172 | void PopStack() { this->scope_stack.pop_back(); } |
173 | }; |
174 | |
/*! \brief Error thrown by InternTable::Add when a name is already present in the table. */
struct DuplicateKeyError : public Error {
  /*! \brief Forward the message to the Error base class. */
  explicit DuplicateKeyError(const std::string& msg) : Error(msg) {}
};
178 | |
179 | /*! \brief A table of interning strings as global function and type names. */ |
180 | template <typename T> |
181 | struct InternTable { |
182 | /*! \brief The internal table mapping strings to a unique allocation. */ |
183 | std::unordered_map<std::string, T> table; |
184 | DiagnosticContext* ctx; |
185 | |
186 | /*! \brief Add the unique allocation. */ |
187 | void Add(const std::string& name, const T& t) { |
188 | auto it = table.find(name); |
189 | if (it != table.end()) { |
190 | throw DuplicateKeyError("duplicate key name in intern table" ); |
191 | } else { |
192 | table.insert({name, t}); |
193 | } |
194 | } |
195 | |
196 | /*! \brief Return the unique allocation. */ |
197 | Optional<T> Get(const std::string& name) const { |
198 | auto it = table.find(name); |
199 | if (it != table.end()) { |
200 | return Optional<T>(it->second); |
201 | } else { |
202 | return Optional<T>(); |
203 | } |
204 | } |
205 | }; |
206 | |
207 | GlobalVar AddOrGet(InternTable<GlobalVar>* table, const std::string& name) { |
208 | auto var = table->Get(name); |
209 | if (var) { |
210 | return var.value(); |
211 | } else { |
212 | auto gvar = GlobalVar(name); |
213 | table->Add(name, gvar); |
214 | return gvar; |
215 | } |
216 | } |
217 | |
218 | GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& name, |
219 | TypeKind kind = TypeKind::kType) { |
220 | auto var = table->Get(name); |
221 | if (var) { |
222 | auto tvar = var.value(); |
223 | TypeKind& tvar_kind = const_cast<TypeKind&>(tvar->kind); |
224 | tvar_kind = kind; |
225 | return tvar; |
226 | } else { |
227 | auto gvar = GlobalTypeVar(name, kind); |
228 | table->Add(name, gvar); |
229 | return gvar; |
230 | } |
231 | } |
232 | |
/*! \brief The parser class is the main interface to the parser.
 * The parser is not currently exposed beyond this .cc file.
235 | * |
236 | * The parser is initialized with a diagnostic context, an |
237 | * operator table, and a token stream. |
238 | * |
239 | * The rest of the internal state is used to map the human readable |
240 | * form to in-memory IR representation. |
241 | * |
 * The main entry points to the parser are a set of parsing methods
 * such as `ParseModule` and `ParseExpr`.
244 | * |
245 | * As with traditional recursive descent parsers the parsing methods |
246 | * are factored recursively just as one would do with a formal language |
247 | * grammar. |
248 | * |
249 | * You can view a recursive descent parser as a human friendly way to specify |
250 | * a state machine, and thus this factoring is necessary as the 'state' of this |
251 | * machine is the combination of the current parsing method and the next token. |
252 | * |
253 | * Parsing proceeds by matching a token and then dispatching to the appropriate |
254 | * method to parse the next tokens in the stream. |
255 | * |
256 | * For example if we are parsing a type and encounter a "Tensor" token we switch |
257 | * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`. |
258 | * |
259 | * Certain matches like this are unambiguous and proceed in a straight line fashion |
260 | * once the initial token is found. Other parsing is more complex and requires some |
261 | * tricks to correctly parse. |
262 | * |
 * For example when we find a '(' in an expression context, it may be part of
 * a tuple, the arguments to a call, or a parenthesized expression. The code below
 * disambiguates these cases by factoring expression parsing into a series of methods
 * which encode the parsing context and thus how to interpret the parenthesis.
267 | * |
268 | * For more information one should be able to read the code in order starting with |
269 | * `ParseModule` or `ParseExpr`. |
270 | */ |
271 | class Parser { |
272 | public: |
/*! \brief The version that the parser is parsing. */
SemVer version;

/*! \brief The IRModule we are building. */
IRModule module;

/*! \brief The diagnostic context used for error reporting. */
DiagnosticContext diag_ctx;

/*! \brief The Source handed to the constructor. It is held by reference, so the
 *  referenced object must outlive the parser. */
const Source& source;

/*! \brief The current position in the token stream. */
int pos;

/*! \brief The token stream for the parser. */
std::vector<Token> tokens;

/*! \brief The configured operator table. */
OperatorTable op_table;

/*! \brief Configure the whitespace mode, right now we ignore all whitespace. */
bool ignore_whitespace;

/*! \brief A global mapping for GlobalVar. */
InternTable<GlobalVar> global_names;

/*! \brief A global mapping for type definitions. */
InternTable<GlobalTypeVar> type_names;

/*! \brief A global mapping for constructor names. */
InternTable<Constructor> ctors;

/*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */
std::unordered_map<int, Expr> graph_ctx;

/*! \brief The set of type scopes used for generics. */
ScopeStack<TypeVar> type_scopes;

/*! \brief The set of expression scopes used for lexical scope. */
ScopeStack<Var> expr_scopes;

/*! \brief The metadata section. */
MetaTable meta_table;
316 | |
317 | Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector<Token> tokens, |
318 | OperatorTable op_table, MetaTable table) |
319 | : module(module), |
320 | diag_ctx(ctx), |
321 | source(source), |
322 | pos(0), |
323 | tokens(tokens), |
324 | op_table(op_table), |
325 | ignore_whitespace(true), |
326 | meta_table(table) { |
327 | InitializeGlobals(); |
328 | InitializeTypeDefs(); |
329 | } |
330 | |
331 | /*! If we are parsing into a module with previously loaded data types we need to |
332 | * map constructor names and variable names in the global tables. |
333 | */ |
334 | void InitializeTypeDefs() { |
335 | for (auto pair : this->module->type_definitions) { |
336 | type_names.Add(pair.first->name_hint, pair.first); |
337 | for (auto ctor : pair.second->constructors) { |
338 | ctors.Add(ctor->name_hint, ctor); |
339 | } |
340 | } |
341 | } |
342 | |
343 | void InitializeGlobals() { |
344 | for (auto pair : this->module->functions) { |
345 | global_names.Add(pair.first->name_hint, pair.first); |
346 | } |
347 | } |
348 | |
349 | /*! \brief Examine the next token in the stream, the current parser is configured to be |
350 | * whitespace insensitive so we will skip all whitespace or comment tokens. */ |
351 | Token Peek() { |
352 | // For now we ignore all whitespace tokens and comments. |
353 | // We can tweak this behavior later to enable white space sensitivity in the parser. |
354 | while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace && |
355 | (tokens.at(pos)->token_type == TokenType::kWhitespace || |
356 | tokens.at(pos)->token_type == TokenType::kNewline || |
357 | tokens.at(pos)->token_type == TokenType::kLineComment || |
358 | tokens.at(pos)->token_type == TokenType::kComment)) { |
359 | pos++; |
360 | } |
361 | |
362 | if (pos < static_cast<int64_t>(tokens.size())) { |
363 | return Token(this->tokens.at(pos)); |
364 | } else { |
365 | return Token::Null(); |
366 | } |
367 | } |
368 | |
369 | /*! \brief Lookahead by N tokens. |
370 | * \param n The number of tokens to lookahead. |
371 | * \return The Nth token. |
372 | */ |
373 | Token Lookahead(int n) { |
374 | ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1" ; |
375 | |
376 | // We intend to skip n - 1 tokens, then return the nth. |
377 | auto old_pos = pos; |
378 | for (int i = 0; i < n - 1; i++) { |
379 | Peek(); |
380 | pos++; |
381 | } |
382 | |
383 | auto tok = Peek(); |
384 | pos = old_pos; |
385 | return tok; |
386 | } |
387 | |
388 | /*! \brief Consume a token, this method is the lowest level way to consume a token |
389 | * and will not ignore white space or look ahead in anyway. |
390 | * |
391 | * /param token_type The token type to match. |
392 | */ |
393 | void Consume(const TokenType& token_type) { |
394 | if (tokens[pos]->token_type != token_type) { |
395 | this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span) |
396 | << "expected a " << Pretty(token_type) << " found " |
397 | << Pretty(Peek()->token_type)); |
398 | } |
399 | pos++; |
400 | } |
401 | |
402 | /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such |
403 | * as whitespace or comments returning the first meaningful token. |
404 | * |
405 | * We then try and consume the requested token, this will trigger an error if the |
406 | * current token does not match the token_type. |
407 | */ |
408 | Token Match(const TokenType& token_type) { |
409 | auto tok = Peek(); |
410 | Consume(token_type); |
411 | return tok; |
412 | } |
413 | |
414 | /*! Conditionally consume a token when it matches, this will never trigger an error |
415 | * as we guard against consuming the token before we do. |
416 | * |
417 | * Useful for matching optional tokens, effectively looksahead by one. |
418 | */ |
419 | bool WhenMatch(const TokenType& token_type) { |
420 | VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek(); |
421 | if (Peek()->token_type == token_type) { |
422 | Consume(token_type); |
423 | return true; |
424 | } else { |
425 | return false; |
426 | } |
427 | } |
428 | |
429 | /* \brief Add a graph binding to the parsing context |
430 | * |
431 | * For example if we parse %0 = add(...), map 0 -> add(...), etc. |
432 | */ |
433 | void AddGraphBinding(const Token& token, const Expr& expr) { |
434 | auto graph_no = token.ToNumber(); |
435 | this->graph_ctx.insert({graph_no, expr}); |
436 | } |
437 | |
438 | /* \brief Lookup a previously bound graph variable. |
439 | * |
440 | * Note: we take tokens in all lookup methods so that we |
441 | * that we can do error reporting based on token location. |
442 | */ |
443 | Expr LookupGraphBinding(const Token& token) { |
444 | auto graph_no = token.ToNumber(); |
445 | auto it = this->graph_ctx.find(graph_no); |
446 | if (it != this->graph_ctx.end()) { |
447 | return it->second; |
448 | } else { |
449 | LOG(FATAL) << "Local variable %" << graph_no << " has not yet been defined" ; |
450 | throw; |
451 | } |
452 | } |
453 | |
454 | /*! \brief Bind a local variable in the expression scope. |
455 | * |
456 | * "x" -> Var("x"), these are needed to map from the raw string names |
457 | * to unique variable nodes. |
458 | * If a virtual device is specified, sets the virtual device of the variable. |
459 | */ |
460 | Var BindVar(const std::string& name, const relay::Type& type_annotation, |
461 | Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) { |
462 | auto var = Var(name, type_annotation); |
463 | var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained()); |
464 | VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var); |
465 | this->expr_scopes.Add(name, var); |
466 | return var; |
467 | } |
468 | |
469 | /*! \brief Bind a local variable in the expression scope. |
470 | * |
471 | * "x" -> Var("x"), these are needed to map from the raw string names |
472 | * to unique variable nodes. |
473 | */ |
474 | Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { |
475 | auto var = Var(name, type_annotation); |
476 | this->expr_scopes.AddFreeVar(name, var); |
477 | return var; |
478 | } |
479 | |
480 | /*! \brief Bind a type variable in the type scope. |
481 | * |
482 | * "A" -> TypeVar("A", ...), these are needed to map from raw string names |
483 | * to unique type variable nodes. |
484 | */ |
485 | TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) { |
486 | auto type_var = TypeVar(name, type_kind); |
487 | this->type_scopes.Add(name, type_var); |
488 | return type_var; |
489 | } |
490 | |
491 | /*! \brief Lookup a variable in the expression scope. |
492 | * |
493 | * Note: all lookup methods take tokens intentionally for error reporting information. |
494 | */ |
495 | Var LookupLocal(const Token& local) { |
496 | auto var = this->expr_scopes.Lookup(local.ToString()); |
497 | if (!var.defined()) { |
498 | diag_ctx.Emit(Diagnostic::Error(local->span) |
499 | << "this local variable has not been previously declared" ); |
500 | } |
501 | return var; |
502 | } |
503 | |
504 | /*! \brief Lookup a variable in the type scope. |
505 | * |
506 | * Note: all lookup methods take tokens intentionally for error reporting information. |
507 | */ |
508 | TypeVar LookupTypeVar(const Token& ident) { |
509 | auto var = this->type_scopes.Lookup(ident.ToString()); |
510 | return var; |
511 | } |
512 | |
513 | /*! \brief Add an expression scope to the scope stack. */ |
514 | void PushScope() { this->expr_scopes.PushStack(); } |
515 | |
516 | /*! \brief Remove N expression scopes from the scope stack. */ |
517 | void PopScopes(int n) { |
518 | for (int i = 0; i < n; i++) { |
519 | this->expr_scopes.PopStack(); |
520 | } |
521 | } |
522 | |
523 | /*! \brief Add an type scope to the scope stack. */ |
524 | void PushTypeScope() { this->type_scopes.PushStack(); } |
525 | |
526 | /*! \brief Remove N type scopes from the scope stack. */ |
527 | void PopTypeScopes(int n) { |
528 | for (int i = 0; i < n; i++) { |
529 | this->type_scopes.PopStack(); |
530 | } |
531 | } |
532 | |
533 | /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ |
534 | NDArray NumberToNDArray(const Token& token) { |
535 | if (token->token_type == TokenType::kInteger) { |
536 | return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data)); |
537 | } else if (token->token_type == TokenType::kFloat) { |
538 | return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data)); |
539 | } else { |
540 | LOG(FATAL) << "internal error: should only call this function on numeric tokens" ; |
541 | } |
542 | } |
543 | |
544 | [[noreturn]] void ParseError(const Token& token, const std::string& msg) { |
545 | throw std::runtime_error(msg); |
546 | } |
547 | |
548 | /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */ |
549 | template <typename R> |
550 | R Bracket(TokenType open, TokenType close, std::function<R()> parser) { |
551 | Match(open); |
552 | R result = parser(); |
553 | Match(close); |
554 | return result; |
555 | } |
556 | |
557 | /*! \brief Parse `(` parser() `)`. */ |
558 | template <typename R> |
559 | R Parens(std::function<R()> parser) { |
560 | return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); |
561 | } |
562 | |
563 | /*! \brief Parse `{` parser() `}`. */ |
564 | template <typename R> |
565 | R Block(std::function<R()> parser) { |
566 | return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); |
567 | } |
568 | |
569 | template <typename R> |
570 | R WithSpan(std::function<R()> parser) { |
571 | auto start_span = Peek()->span; |
572 | VLOG(9) << "WithSpan: start_span = " << start_span; |
573 | R ast = parser(); |
574 | if (ast.defined()) { |
575 | // The token at the head of the stream is now 1 past where we parsed. So we find its start |
576 | // position as its start and end, so that when we merge we only grow the spanned region |
577 | // to the start of the current stream. |
578 | auto span_pos = pos - 1; |
579 | while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace || |
580 | tokens.at(span_pos)->token_type == TokenType::kNewline || |
581 | tokens.at(span_pos)->token_type == TokenType::kLineComment || |
582 | tokens.at(span_pos)->token_type == TokenType::kComment)) { |
583 | span_pos--; |
584 | } |
585 | auto end_token = tokens.at(span_pos); |
586 | VLOG(9) << "WithSpan: end_span = " << end_token->span; |
587 | ast->span = start_span.Merge(end_token->span); |
588 | } |
589 | return ast; |
590 | } |
591 | |
592 | struct MetaRef { |
593 | std::string type_key; |
594 | uint64_t node_index; |
595 | Span span; |
596 | MetaRef(std::string type_key, uint64_t node_index, Span span) |
597 | : type_key(type_key), node_index(node_index), span(span) {} |
598 | }; |
599 | |
600 | MetaRef MetaRefFromToken(const Token& tok) { |
601 | Call ref = Downcast<Call>(tok->data); |
602 | auto attrs = ref->attrs.as<MetaRefAttrs>(); |
603 | auto type_key = attrs->node_type_key; |
604 | auto index = attrs->node_index; |
605 | return MetaRef(type_key, index, ref->span); |
606 | } |
607 | |
608 | /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. |
609 | * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` |
610 | * the second, and so on. |
611 | */ |
612 | ObjectRef ParseMetaRef() { |
613 | auto meta_ref_tok = Match(TokenType::kMetaReference); |
614 | auto meta_ref = MetaRefFromToken(meta_ref_tok); |
615 | auto it = this->meta_table.find(meta_ref.type_key); |
616 | if (it != this->meta_table.end()) { |
617 | auto nodes = (*it).second; |
618 | if (meta_ref.node_index < nodes.size()) { |
619 | return nodes[meta_ref.node_index]; |
620 | } else { |
621 | this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
622 | << "the node index `" << meta_ref.node_index |
623 | << "` is out of bounds for `" << meta_ref.type_key << "`" ); |
624 | return ObjectRef(); |
625 | } |
626 | } else { |
627 | this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
628 | << "no entry in the meta table for `" << meta_ref.type_key << "`" ); |
629 | return ObjectRef(); |
630 | } |
631 | } |
/*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
 * ending with a stop token.
 *
 * The simple form being <start> (<parse()> <separator>)* <stop>.
 *
 * This also provides a fifth argument, `before_stop`, which is given a chance to run before
 * the first element and before each element after the first separator.
 *
 * This is useful for parsing things like attributes which don't match the standard expression
 * parsers but are contained within the stop token.
 *
 * \param start The token type that opens the sequence.
 * \param sep The token type that separates elements.
 * \param stop The token type that closes the sequence.
 * \param parse Callback that parses one element.
 * \param before_stop Optional callback; when it returns true the stop token is matched
 *  immediately and the elements collected so far are returned.
 * \return The parsed elements.
 */
template <typename T>
Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
                       std::function<bool()> before_stop = nullptr) {
  VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
          << " stop=" << ToString(stop);
  Match(start);

  // This is for the empty arguments list case, if we have <start> <leftovers> <stop> token stream
  // we must parse leftovers, then match a stop token.
  if (before_stop) {
    auto did_parse = before_stop();
    if (did_parse) {
      Match(stop);
      return {};
    }
  }

  // This is the case in which we find an empty arguments lists and no leftovers.
  if (WhenMatch(stop)) {
    return Array<T>();
  } else {
    VLOG(9) << "Parser::ParseSequence: parse first";
    auto data = parse();
    Array<T> elements = {data};

    // A single element directly followed by the stop token.
    if (WhenMatch(stop)) {
      return elements;
      // parse '( expr ',' * ')'
    } else if (WhenMatch(sep)) {
      while (true) {
        VLOG(9) << "Parser::ParseSequence: parse element";
        // Checking for the stop token first is what permits a trailing separator.
        if (WhenMatch(stop)) {
          break;
        } else {
          // If before_stop is provided, let it try to parse the leftovers before the next element.
          if (before_stop) {
            auto did_parse = before_stop();
            if (did_parse) {
              Match(stop);
              return elements;
            }
          }
          auto data = parse();
          // The result is ignored, so a separator between later elements is optional here.
          WhenMatch(sep);
          elements.push_back(data);
        }
      }
      return elements;
    } else {
      // Neither a stop nor a separator follows the first element.
      auto next = Peek();
      this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
                               << "expected a " << Pretty(stop) << " found "
                               << Pretty(next->token_type));
      return Array<T>(nullptr);
    }
  }
}
700 | |
701 | /*! \brief Parse a full IRModule. */ |
702 | IRModule ParseModule() { |
703 | // Parse the semver header at the top of the module. |
704 | this->version = ParseSemVer(); |
705 | // Parse the definitions. |
706 | auto defs = ParseDefinitions(); |
707 | // Parse the metadata section at the end. |
708 | auto metadata = ParseMetadata(); |
709 | |
710 | Match(TokenType::kEndOfFile); |
711 | |
712 | for (auto type_def : defs.types) { |
713 | module->AddTypeDef(type_def->header, type_def); |
714 | } |
715 | |
716 | for (auto func : defs.funcs) { |
717 | module->Add(func.global, func.function, true); |
718 | } |
719 | |
720 | return module; |
721 | } |
722 | |
723 | /*! \brief Parse the semantic versioning header. */ |
724 | SemVer ParseSemVer(bool required = true) { |
725 | if (Peek()->token_type == TokenType::kVersion) { |
726 | auto version = Match(TokenType::kVersion); |
727 | // TODO(@jroesch): we currently only support 0.0.5. |
728 | if (version.ToString() != "\"0.0.5\"" ) { |
729 | this->diag_ctx.Emit(Diagnostic::Error(version->span) |
730 | << "invalid semantic version `" << version.ToString() << "`" ); |
731 | } |
732 | } else if (required) { |
733 | this->diag_ctx.Emit(Diagnostic::Error(Peek()->span) |
734 | << "expected text format semantic version, found a " |
735 | << PrettyPrint(Peek())); |
736 | |
737 | this->diag_ctx.Emit(Diagnostic::Help(Peek()->span) |
738 | << "you can annotate it as #[version = \"0.0.5\"]" ); |
739 | } |
740 | return SemVer(0, 0, 5); |
741 | } |
742 | |
743 | /*! \brief Parse zero or more Relay definitions. */ |
744 | Definitions ParseDefinitions() { |
745 | Definitions defs; |
746 | |
747 | while (true) { |
748 | auto next = Peek(); |
749 | switch (next->token_type) { |
750 | case TokenType::kDefn: { |
751 | Consume(TokenType::kDefn); |
752 | auto global_tok = Match(TokenType::kGlobal); |
753 | auto global_name = global_tok.ToString(); |
754 | auto global = AddOrGet(&global_names, global_name); |
755 | auto func = WithSpan<relay::Function>([&]() { return ParseFunctionDef(); }); |
756 | ICHECK(func->span.defined()) << "spans must be set in parser" ; |
757 | defs.funcs.push_back(GlobalFunc(global, func)); |
758 | continue; |
759 | } |
760 | case TokenType::kTypeDef: { |
761 | defs.types.push_back(ParseTypeDef()); |
762 | continue; |
763 | } |
764 | case TokenType::kExtern: { |
765 | Consume(TokenType::kExtern); |
766 | auto type_def = ParseTypeDef(); |
767 | if (type_def->constructors.size()) { |
768 | diag_ctx.Emit(Diagnostic::Error(next->span) |
769 | << "an external type may not have any constructors" ); |
770 | } |
771 | defs.types.push_back(type_def); |
772 | } |
773 | default: |
774 | return defs; |
775 | } |
776 | } |
777 | } |
778 | |
779 | /*! \brief Parse zero or more Relay type definitions. */ |
780 | TypeData ParseTypeDef() { |
781 | // Match the `type` keyword. |
782 | Match(TokenType::kTypeDef); |
783 | // Parse the type's identifier. |
784 | auto type_tok = Match(TokenType::kIdentifier); |
785 | auto type_id = type_tok.ToString(); |
786 | auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle); |
787 | |
788 | Array<TypeVar> generics; |
789 | |
790 | bool should_pop = false; |
791 | if (Peek()->token_type == TokenType::kLSquare) { |
792 | // If we have generics we need to add a type scope. |
793 | PushTypeScope(); |
794 | should_pop = true; |
795 | generics = ParseSequence<TypeVar>( |
796 | TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
797 | auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
798 | return BindTypeVar(type_var_name, TypeKind::kType); |
799 | }); |
800 | } |
801 | |
802 | Array<tvm::Constructor> ctors; |
803 | if (Peek()->token_type == TokenType::kLCurly) { |
804 | // Parse the list of constructors. |
805 | ctors = ParseSequence<tvm::Constructor>( |
806 | TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { |
807 | // First match the name of the constructor. |
808 | auto ctor_tok = Match(TokenType::kIdentifier); |
809 | auto ctor_name = ctor_tok.ToString(); |
810 | |
811 | Constructor ctor; |
812 | // Match the optional field list. |
813 | if (Peek()->token_type != TokenType::kOpenParen) { |
814 | ctor = tvm::Constructor(ctor_name, {}, type_global); |
815 | } else { |
816 | auto arg_types = |
817 | ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
818 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
819 | ctor = tvm::Constructor(ctor_name, arg_types, type_global); |
820 | } |
821 | |
822 | ICHECK(ctor.defined()); |
823 | |
824 | try { |
825 | this->ctors.Add(ctor_name, ctor); |
826 | } catch (const DuplicateKeyError& e) { |
827 | this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span) |
828 | << "a constructor with the name " |
829 | << "`" << ctor_name << "` " |
830 | << "was previously defined" ); |
831 | } |
832 | |
833 | return ctor; |
834 | }); |
835 | } |
836 | |
837 | // Now pop the type scope. |
838 | if (should_pop) { |
839 | PopTypeScopes(1); |
840 | } |
841 | |
842 | return TypeData(type_global, generics, ctors); |
843 | } |
844 | |
845 | std::string HackTokensAsString(int n) { |
846 | std::stringstream key; |
847 | n = std::min(static_cast<int>(tokens.size() - pos), n); |
848 | for (int i = 0; i < n; i++) { |
849 | key << ToString(tokens.at(pos + i)->token_type); |
850 | } |
851 | return key.str(); |
852 | } |
853 | |
854 | std::vector<Rule> ParseOp() { |
855 | std::vector<Rule> matched; |
856 | Peek(); |
857 | for (int i = 4; i > 0; i--) { |
858 | auto key = HackTokensAsString(i); |
859 | auto it = this->op_table.this_is_a_hack.find(key); |
860 | if (it != this->op_table.this_is_a_hack.end()) { |
861 | pos = pos + i; |
862 | matched.push_back(it->second); |
863 | } |
864 | } |
865 | |
866 | return matched; |
867 | } |
868 | |
869 | /*! \brief Parse a single Relay expression. */ |
870 | Expr ParseExpr() { |
871 | VLOG(9) << "Parser::ParseExpr" ; |
872 | return WithSpan<Expr>([this] { |
873 | std::vector<Expr> exprs; |
874 | |
875 | while (true) { |
876 | VLOG(9) << "Parser::ParseExpr: parsing a single expression" ; |
877 | auto next = Peek(); |
878 | switch (next->token_type) { |
879 | // For graph or let, match first rhs, then invoke ParseBindingExpr |
880 | // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue |
881 | case TokenType::kLCurly: { |
882 | // NB: Might need to optimize to remove deep recursion. |
883 | // Stack should only grow proportionally to the number of |
884 | // nested scopes. |
885 | // Parses `{` expression `}`. |
886 | auto block = WithSpan<Expr>([&]() { |
887 | return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() { |
888 | PushScope(); |
889 | auto expr = ParseExpr(); |
890 | PopScopes(1); |
891 | return expr; |
892 | }); |
893 | }); |
894 | exprs.push_back(block); |
895 | break; |
896 | } |
897 | case TokenType::kFreeVar: { |
898 | Consume(TokenType::kFreeVar); |
899 | auto var_token = Match(TokenType::kLocal); |
900 | |
901 | Type type; |
902 | if (WhenMatch(TokenType::kColon)) { |
903 | type = ParseType(); |
904 | } else { |
905 | type = IncompleteType(); |
906 | } |
907 | |
908 | BindFreeVar(var_token.ToString(), type); |
909 | break; |
910 | } |
911 | // Parses `let ...`; |
912 | case TokenType::kLet: |
913 | exprs.push_back(ParseBindingExpr()); |
914 | break; |
915 | case TokenType::kMatch: |
916 | case TokenType::kPartialMatch: { |
917 | bool is_total = next->token_type == TokenType::kMatch; |
918 | Consume(next->token_type); |
919 | exprs.push_back(ParseMatch(is_total)); |
920 | break; |
921 | } |
922 | |
923 | // %x ... |
924 | case TokenType::kGraph: |
925 | if (Lookahead(2)->token_type == TokenType::kEqual) { |
926 | exprs.push_back(ParseBindingExpr()); |
927 | break; |
928 | } |
929 | // intentional fall through here. |
930 | default: { |
931 | exprs.push_back(ParseExprBinOp()); |
932 | break; |
933 | } |
934 | } |
935 | |
936 | if (!WhenMatch(TokenType::kSemicolon)) { |
937 | break; |
938 | } |
939 | } |
940 | |
941 | ICHECK_GE(exprs.size(), 1); |
942 | |
943 | if (exprs.size() == 1) { |
944 | // ICHECK(exprs[0].defined() && exprs[0]->span.defined()) |
945 | // << "parser must set expression spans.\n" |
946 | // << exprs[0]; |
947 | return exprs[0]; |
948 | } else { |
949 | auto body = exprs.back(); |
950 | exprs.pop_back(); |
951 | while (exprs.size()) { |
952 | auto value = exprs.back(); |
953 | ICHECK(value->span.defined()) << "parser must set expression spans." ; |
954 | exprs.pop_back(); |
955 | body = relay::Let(Var("" , IncompleteType()), value, body, value->span.Merge(body->span)); |
956 | } |
957 | ICHECK(body->span.defined()) << "parser must set expression spans." ; |
958 | return body; |
959 | } |
960 | }); |
961 | } |
962 | |
963 | /*! \brief Parse a "binding expression"; an expression where |
964 | * a graph or let variable is bound. |
965 | * |
966 | * In order to avoid stack overflow this is implemented in a special |
967 | * iterative way to keep stack depth constant in a long chain of bindings. |
968 | */ |
969 | Expr ParseBindingExpr() { |
970 | // We use a loop here so that the stack depth |
971 | // does not grow linearly with a sequence of |
972 | // graph or let bindings. |
973 | // |
974 | // Assuming we start at call depth k, we will |
975 | // enter k + c call frames to parse the RHS |
976 | // of the bindings where `c` is the depth |
977 | // of recursion needed by RHS. |
978 | // |
979 | // If RHS is a call expresssion the c=1. |
980 | // |
981 | // Once we have parsed the RHS we will be |
982 | // back at depth K, and will return to |
983 | // this loop header to parse another |
984 | // graph or let binding. |
985 | // |
986 | // This ensures for n sequential bindings |
987 | // the call depth will be the same before |
988 | // and after parsing the n bindings. |
989 | VLOG(9) << "Parser::ParseBindingExpr" ; |
990 | std::vector<std::tuple<Var, Expr, Span>> bindings; |
991 | int scopes = 0; |
992 | |
993 | while (true) { |
994 | auto next = Peek(); |
995 | if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) { |
996 | Match(TokenType::kGraph); |
997 | Match(TokenType::kEqual); |
998 | auto val = this->ParseExprBinOp(); |
999 | Match(TokenType::kSemicolon); |
1000 | AddGraphBinding(next, val); |
1001 | } else if (next->token_type == TokenType::kLet) { |
1002 | auto span = next->span; |
1003 | // Parse the 'let'. |
1004 | Consume(TokenType::kLet); |
1005 | |
1006 | // Parse the local '%<id>'. |
1007 | auto local_tok = Match(TokenType::kLocal); |
1008 | auto string = local_tok.ToString(); |
1009 | |
1010 | // Parse the optional type annotation (':' <type>). |
1011 | Type type; |
1012 | if (WhenMatch(TokenType::kColon)) { |
1013 | type = ParseType(); |
1014 | } |
1015 | |
1016 | auto var = BindVar(string, type); |
1017 | |
1018 | // Parse the '='; |
1019 | Match(TokenType::kEqual); |
1020 | |
1021 | // Parse the body, and the ';'. |
1022 | auto val = this->ParseExprBinOp(); |
1023 | Consume(TokenType::kSemicolon); |
1024 | |
1025 | // Add the bindings to the local data structure. |
1026 | std::tuple<relay::Var, relay::Expr, Span> tuple(var, val, span); |
1027 | bindings.push_back(tuple); |
1028 | scopes++; |
1029 | PushScope(); |
1030 | } else { |
1031 | // This is the only case we will increase the stack |
1032 | // depth. |
1033 | // |
1034 | // If we parse a program which is a sequence of N bindings |
1035 | // followed by a single body expression we will end up with |
1036 | // a call depth of 3, the first call to ParseExpr, then |
1037 | // ParseBindingExpr, then finally ParseExpr once more. |
1038 | |
1039 | auto body = this->ParseExpr(); |
1040 | |
1041 | // Remove the same number of scopes we added. |
1042 | PopScopes(scopes); |
1043 | |
1044 | if (bindings.size() == 0) { |
1045 | return body; |
1046 | } else { |
1047 | // We can now build the let binding up backwards. |
1048 | for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) { |
1049 | auto span = body->span.Merge(std::get<2>(*binding)); |
1050 | body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span); |
1051 | } |
1052 | return body; |
1053 | } |
1054 | } |
1055 | } |
1056 | } |
1057 | |
1058 | /*! Parse a function definition without a leading keyword or identifier. |
1059 | * |
1060 | * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. |
1061 | */ |
1062 | Function ParseFunctionDef() { |
1063 | VLOG(9) << "Parser::ParseFunctionDef" ; |
1064 | return WithSpan<Function>([&]() { |
1065 | PushScope(); |
1066 | PushTypeScope(); |
1067 | |
1068 | Array<TypeVar> generics; |
1069 | if (Peek()->token_type == TokenType::kLSquare) { |
1070 | generics = ParseSequence<TypeVar>( |
1071 | TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
1072 | auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
1073 | return BindTypeVar(type_var_name, TypeKind::kType); |
1074 | }); |
1075 | } |
1076 | |
1077 | Map<String, ObjectRef> raw_attrs; |
1078 | |
1079 | auto params = ParseSequence<Var>( |
1080 | TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
1081 | [&]() { |
1082 | auto token = Match(TokenType::kLocal); |
1083 | auto string = token.ToString(); |
1084 | |
1085 | // The fake attributes where the virtual device is specified. |
1086 | VirtualDevice virtual_device; |
1087 | if (WhenMatch(TokenType::kLCurly)) { |
1088 | Map<String, ObjectRef> fake_attrs = ParseAttrs(); |
1089 | VLOG(9) << "Fake attributes for function parameter: " << fake_attrs; |
1090 | Match(TokenType::kRCurly); |
1091 | if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) { |
1092 | ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>()) |
1093 | << "Expected the " << kVirtualDevice |
1094 | << " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey(); |
1095 | virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]); |
1096 | } |
1097 | } |
1098 | |
1099 | Type type; |
1100 | if (WhenMatch(TokenType::kColon)) { |
1101 | type = ParseType(); |
1102 | } |
1103 | return BindVar(string, type, virtual_device); |
1104 | }, |
1105 | [&] { |
1106 | auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; |
1107 | auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; |
1108 | |
1109 | if (is_ident && next_is_equal) { |
1110 | raw_attrs = ParseAttrs(); |
1111 | return true; |
1112 | } |
1113 | |
1114 | return false; |
1115 | }); |
1116 | |
1117 | Type ret_type; |
1118 | if (WhenMatch(TokenType::kMinus)) { |
1119 | Match(TokenType::kRAngle); |
1120 | ret_type = ParseType(); |
1121 | } |
1122 | |
1123 | auto body = Block<Expr>([&]() { return ParseExpr(); }); |
1124 | |
1125 | PopTypeScopes(1); |
1126 | PopScopes(1); |
1127 | |
1128 | // TODO(@jroesch): attributes should never be null, they should always be empty. |
1129 | if (raw_attrs.size()) { |
1130 | // Promote kVirtualDevice to first-class |
1131 | if (raw_attrs.count(kVirtualDevice)) { |
1132 | ObjectRef vid = raw_attrs.at(kVirtualDevice); |
1133 | ICHECK(vid.as<VirtualDeviceNode>()) |
1134 | << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " |
1135 | << vid->GetTypeKey(); |
1136 | |
1137 | DictAttrs attrs; |
1138 | // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the |
1139 | // attributes |
1140 | if (raw_attrs.size() > 1) { |
1141 | raw_attrs.erase(kVirtualDevice); |
1142 | attrs = DictAttrs(raw_attrs); |
1143 | } |
1144 | Function func = relay::Function(params, body, ret_type, generics, attrs); |
1145 | func->virtual_device_ = vid; |
1146 | return func; |
1147 | } else { |
1148 | return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); |
1149 | } |
1150 | } else { |
1151 | return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); |
1152 | } |
1153 | }); |
1154 | } |
1155 | |
1156 | /*! \brief Parse an if-expression. */ |
1157 | Expr ParseIf() { |
1158 | return WithSpan<Expr>([&]() { |
1159 | VLOG(9) << "Parser::ParseIf" ; |
1160 | Consume(TokenType::kIf); |
1161 | |
1162 | auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); }); |
1163 | |
1164 | auto true_branch = Block<Expr>([&] { |
1165 | this->PushScope(); |
1166 | auto expr = ParseExpr(); |
1167 | this->PopScopes(1); |
1168 | return expr; |
1169 | }); |
1170 | |
1171 | Match(TokenType::kElse); |
1172 | |
1173 | auto false_branch = Block<Expr>([&] { |
1174 | this->PushScope(); |
1175 | auto expr = ParseExpr(); |
1176 | this->PopScopes(1); |
1177 | return expr; |
1178 | }); |
1179 | |
1180 | return relay::If(guard, true_branch, false_branch); |
1181 | }); |
1182 | } |
1183 | |
1184 | /* This factors parsing a list of patterns for both tuples, and constructors. */ |
1185 | Array<Pattern> ParsePatternList() { |
1186 | return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
1187 | [&] { return ParsePattern(); }); |
1188 | } |
1189 | |
1190 | /*! \brief Parses a pattern for a match expression. |
1191 | * |
1192 | * A pattern is either a wildcard `_`, a local `%name`, |
1193 | * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn). |
1194 | * |
1195 | * This function recursively parses a pattern. |
1196 | */ |
1197 | Pattern ParsePattern() { |
1198 | VLOG(9) << "Parser::ParsePattern" ; |
1199 | auto next = Peek(); |
1200 | switch (next->token_type) { |
1201 | case TokenType::kUnderscore: { |
1202 | Match(TokenType::kUnderscore); |
1203 | return PatternWildcard(); |
1204 | } |
1205 | case TokenType::kLocal: { |
1206 | auto id = Match(TokenType::kLocal); |
1207 | Type type_annotation; |
1208 | if (WhenMatch(TokenType::kColon)) { |
1209 | type_annotation = ParseType(); |
1210 | } |
1211 | auto var = BindVar(id.ToString(), type_annotation); |
1212 | return PatternVar(var); |
1213 | } |
1214 | case TokenType::kIdentifier: { |
1215 | auto id = Match(TokenType::kIdentifier); |
1216 | auto ctor = ctors.Get(id.ToString()); |
1217 | if (!ctor) { |
1218 | diag_ctx.EmitFatal( |
1219 | // TODO(@jroesch): split into error and help |
1220 | // deal with multiple rendering |
1221 | Diagnostic::Error(id->span) |
1222 | << "undefined constructor name `" << id.ToString() |
1223 | << "`, perhaps you intended to write a" |
1224 | << "pattern variable, considering changing this to `%" << id.ToString() << "`" ); |
1225 | } |
1226 | if (Peek()->token_type == TokenType::kOpenParen) { |
1227 | auto fields = ParsePatternList(); |
1228 | return PatternConstructor(ctor.value(), fields); |
1229 | } else { |
1230 | return PatternConstructor(ctor.value(), {}); |
1231 | } |
1232 | } |
1233 | default: |
1234 | return PatternTuple(ParsePatternList()); |
1235 | } |
1236 | } |
1237 | |
1238 | Clause ParseMatchArm() { |
1239 | PushScope(); |
1240 | auto pattern = ParsePattern(); |
1241 | Match(TokenType::kEqual); |
1242 | Consume(TokenType::kRAngle); |
1243 | auto expr = ParseExpr(); |
1244 | PopScopes(1); |
1245 | return Clause(pattern, expr); |
1246 | } |
1247 | |
1248 | Expr ParseMatch(bool is_total) { |
1249 | return WithSpan<Expr>([&]() { |
1250 | Expr scrutinee = ParseAtomicExpr(); |
1251 | |
1252 | Array<Clause> clauses = |
1253 | ParseSequence<Clause>(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, |
1254 | [&] { return ParseMatchArm(); }); |
1255 | |
1256 | return relay::Match(scrutinee, clauses, is_total); |
1257 | }); |
1258 | } |
1259 | |
1260 | Expr ParseExprBinOp() { |
1261 | VLOG(9) << "Parser::ParseExprBinOp" ; |
1262 | return WithSpan<Expr>([this] { |
1263 | // We must parse at least one expression, the default |
1264 | // case is that there is no operator and we will fall |
1265 | // through. |
1266 | std::vector<Expr> exprs; |
1267 | Expr expr = WithSpan<Expr>([this] { return ParseCallExpr(); }); |
1268 | |
1269 | exprs.push_back(expr); |
1270 | |
1271 | // Now we parse an optional op. |
1272 | std::vector<Rule> ops; |
1273 | |
1274 | // We will now parse 0 or more operator occurrences. |
1275 | while (true) { |
1276 | auto opt_op = ParseOp(); |
1277 | |
1278 | // If we didn't parse one we done. |
1279 | if (opt_op.size() == 0) { |
1280 | break; |
1281 | } |
1282 | |
1283 | // Read the operation we parsed; |
1284 | auto op = opt_op[0]; |
1285 | |
1286 | Expr right = WithSpan<Expr>([this] { return ParseCallExpr(); }); |
1287 | ICHECK(right->span.defined()); |
1288 | |
1289 | // If the operator stack is empty |
1290 | // we parse an operator and expression |
1291 | // and push them to stacks, then |
1292 | // continue. |
1293 | if (ops.size() == 0) { |
1294 | ops.push_back(op); |
1295 | exprs.push_back(right); |
1296 | continue; |
1297 | } |
1298 | |
1299 | if (op.precedence > ops.back().precedence || |
1300 | (op.precedence == ops.back().precedence && op.left_assoc == false)) { |
1301 | ops.push_back(op); |
1302 | exprs.push_back(right); |
1303 | continue; |
1304 | } |
1305 | |
1306 | while (ops.size() && (op.precedence < ops.back().precedence || |
1307 | (op.precedence == ops.back().precedence && op.left_assoc == true))) { |
1308 | Rule new_op = ops.back(); |
1309 | ops.pop_back(); |
1310 | Expr right = exprs.back(); |
1311 | exprs.pop_back(); |
1312 | Expr left = exprs.back(); |
1313 | exprs.pop_back(); |
1314 | ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; |
1315 | exprs.push_back( |
1316 | relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); |
1317 | } |
1318 | |
1319 | exprs.push_back(right); |
1320 | ops.push_back(op); |
1321 | } |
1322 | |
1323 | while (ops.size()) { |
1324 | Rule new_op = ops.back(); |
1325 | ops.pop_back(); |
1326 | Expr right = exprs.back(); |
1327 | exprs.pop_back(); |
1328 | Expr left = exprs.back(); |
1329 | exprs.pop_back(); |
1330 | ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op; |
1331 | exprs.push_back( |
1332 | relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span))); |
1333 | } |
1334 | |
1335 | ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack." ; |
1336 | |
1337 | ICHECK_EQ(exprs.size(), 1) |
1338 | << "Only a single expression should be left on the expression stack." ; |
1339 | |
1340 | return exprs[0]; |
1341 | }); |
1342 | } |
1343 | |
1344 | ObjectRef ParseAttributeValue() { |
1345 | VLOG(9) << "Parser::ParseAttributeValue" ; |
1346 | auto next = Peek(); |
1347 | switch (next->token_type) { |
1348 | case TokenType::kFloat: |
1349 | case TokenType::kInteger: |
1350 | case TokenType::kBoolean: |
1351 | case TokenType::kStringLiteral: |
1352 | return Match(next->token_type)->data; |
1353 | case TokenType::kMetaReference: |
1354 | return ParseMetaRef(); |
1355 | case TokenType::kLSquare: { |
1356 | return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
1357 | [&]() { return ParseAttributeValue(); }); |
1358 | } |
1359 | case TokenType::kOpenParen: { |
1360 | // TODO(@jroesch: need to figure out bracket vs. sequence) |
1361 | // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma, |
1362 | // TokenType::kCloseParen, |
1363 | // [&]() { return ParseAttributeValue(); }); |
1364 | return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen, |
1365 | [&]() { return ParseAttributeValue(); }); |
1366 | } |
1367 | // TODO(@jroesch): not sure about this being the right way to handle nulls. |
1368 | case TokenType::kIdentifier: { |
1369 | if (auto text = next->data.as<tvm::StringObj>()) { |
1370 | std::string id = GetRef<String>(text); |
1371 | if (id == "nullptr" ) { |
1372 | Match(TokenType::kIdentifier); |
1373 | return ObjectRef(); |
1374 | } |
1375 | if (id == "None" ) { |
1376 | Match(TokenType::kIdentifier); |
1377 | return Optional<ObjectRef>(); |
1378 | } |
1379 | } |
1380 | } |
1381 | default: |
1382 | return ParseAtomicExpr(); |
1383 | } |
1384 | } |
1385 | |
1386 | Map<String, ObjectRef> ParseAttrs() { |
1387 | VLOG(9) << "Parser::ParseAttrs" ; |
1388 | Map<String, ObjectRef> kwargs; |
1389 | while (Peek()->token_type == TokenType::kIdentifier) { |
1390 | auto key = GetHierarchicalName(ParseHierarchicalName().data); |
1391 | Match(TokenType::kEqual); |
1392 | // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. |
1393 | auto value = ParseAttributeValue(); |
1394 | // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text |
1395 | // format is bad. |
1396 | kwargs.Set(key, value); |
1397 | WhenMatch(TokenType::kComma); |
1398 | } |
1399 | VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs; |
1400 | return kwargs; |
1401 | } |
1402 | |
1403 | Expr ParseCallArgs(Expr op) { |
1404 | ICHECK(op.defined()) << "the operator must be defined" ; |
1405 | |
1406 | VLOG(9) << "Parser::ParseCallArgs" ; |
1407 | Attrs attrs; |
1408 | std::string op_key; |
1409 | bool is_op = false; |
1410 | |
1411 | if (auto op_node = op.as<OpNode>()) { |
1412 | is_op = true; |
1413 | op_key = op_node->attrs_type_key; |
1414 | } |
1415 | |
1416 | if (Peek()->token_type == TokenType::kOpenParen) { |
1417 | Array<Expr> args = ParseSequence<Expr>( |
1418 | TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
1419 | [&] { return ParseExpr(); }, |
1420 | [&] { |
1421 | auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; |
1422 | auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; |
1423 | auto is_pretty_attrs = is_ident && next_is_equal; |
1424 | auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference; |
1425 | // TODO(@jroesch): might not handle trailing comma |
1426 | auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen; |
1427 | auto is_meta_attrs = is_meta_next && last_meta; |
1428 | |
1429 | if (is_pretty_attrs || is_meta_attrs) { |
1430 | if (is_meta_attrs) { |
1431 | auto meta_ref = ParseMetaRef(); |
1432 | if (meta_ref.as<BaseAttrsNode>()) { |
1433 | attrs = Downcast<Attrs>(meta_ref); |
1434 | } else { |
1435 | // Not awesome parsing code here. |
1436 | this->pos--; |
1437 | return false; |
1438 | } |
1439 | } else { |
1440 | auto raw_attrs = ParseAttrs(); |
1441 | if (is_op && op_key.size()) { |
1442 | auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); |
1443 | ICHECK(attr_obj.defined()); |
1444 | attrs = Downcast<Attrs>(attr_obj); |
1445 | } else if (raw_attrs.count("attrs_type_key" )) { |
1446 | String attr_key = Downcast<String>(raw_attrs["attrs_type_key" ]); |
1447 | if (attr_key.size()) { |
1448 | raw_attrs.erase("attrs_type_key" ); |
1449 | auto attr_obj = |
1450 | tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); |
1451 | ICHECK(attr_obj.defined()); |
1452 | attrs = Downcast<Attrs>(attr_obj); |
1453 | } |
1454 | } else { |
1455 | this->diag_ctx.EmitFatal(Diagnostic::Error(op->span) |
1456 | << "unable to determine the 'attrs_type_key' with which " |
1457 | "to represent the call attributes for this operator" ); |
1458 | } |
1459 | } |
1460 | return true; |
1461 | } |
1462 | return false; |
1463 | }); |
1464 | |
1465 | if (!attrs.defined()) { |
1466 | if (is_op && op_key.size()) { |
1467 | auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {}); |
1468 | ICHECK(attr_obj.defined()); |
1469 | attrs = Downcast<Attrs>(attr_obj); |
1470 | } |
1471 | } |
1472 | |
1473 | // TODO(@jroesch): in a secondary pass adjust spans. |
1474 | return Expr(Call(op, args, attrs, {})); |
1475 | } else { |
1476 | return Expr(); |
1477 | } |
1478 | |
1479 | return Expr(); |
1480 | } |
1481 | |
1482 | Expr ParseCallExpr() { |
1483 | VLOG(9) << "Parser::ParseCallExpr" ; |
1484 | return WithSpan<Expr>([this] { |
1485 | Expr expr = ParseAtomicExpr(); |
1486 | // Parse as many call args as possible, building up expression |
1487 | // |
1488 | // NB(@jroesch): this seems like a hack but in order to parse curried functions |
1489 | // and avoid complex grammar we will parse multiple call lists in a row. |
1490 | while (Peek()->token_type == TokenType::kOpenParen) { |
1491 | auto new_expr = ParseCallArgs(expr); |
1492 | |
1493 | if (new_expr.defined()) { |
1494 | expr = new_expr; |
1495 | } else { |
1496 | break; |
1497 | } |
1498 | } |
1499 | |
1500 | // We need a zero-arity case for constructors. |
1501 | if (auto ctor_node = expr.as<ConstructorNode>()) { |
1502 | if (ctor_node->inputs.size() == 0) { |
1503 | return Expr(Call(expr, {})); |
1504 | } |
1505 | } |
1506 | |
1507 | return expr; |
1508 | }); |
1509 | } |
1510 | |
1511 | Expr GetOp(const std::string& op_name, const Span& span) { |
1512 | VLOG(9) << "op_name=" << op_name << " span=" << span; |
1513 | try { |
1514 | return Op::Get(op_name); |
1515 | } catch (const Error& e) { |
1516 | // we can relax this, but probably need to relax checks or return non-null here. |
1517 | this->diag_ctx.EmitFatal(Diagnostic::Error(span) |
1518 | << "operator `" << op_name |
1519 | << "` not found, perhaps you forgot to register it?" ); |
1520 | return Expr(); |
1521 | } |
1522 | } |
1523 | |
1524 | Expr ParseAtomicExpr() { |
1525 | VLOG(9) << "Parser::ParseAtomicExpr" ; |
1526 | Expr expr = WithSpan<Expr>([this] { |
1527 | auto next = Peek(); |
1528 | switch (next->token_type) { |
1529 | case TokenType::kInteger: |
1530 | case TokenType::kFloat: { |
1531 | Consume(next->token_type); |
1532 | auto number = NumberToNDArray(next); |
1533 | Expr e = Constant(number, next->span); |
1534 | ICHECK(e->span.defined()) << "constant spans must be defined" ; |
1535 | return e; |
1536 | } |
1537 | case TokenType::kBoolean: { |
1538 | Consume(TokenType::kBoolean); |
1539 | int64_t value = Downcast<tvm::Integer>(next->data).IntValue(); |
1540 | Expr e = Constant(support::BoolToNDArray(value), next->span); |
1541 | ICHECK(e->span.defined()) << "constant spans must be defined" ; |
1542 | return e; |
1543 | } |
1544 | // Parse a local of the form `%x`. |
1545 | case TokenType::kLocal: { |
1546 | Consume(TokenType::kLocal); |
1547 | return Expr(LookupLocal(next)); |
1548 | } |
1549 | // Parse a local of the form `@x`. |
1550 | case TokenType::kGlobal: { |
1551 | auto global_name = next.ToString(); |
1552 | Consume(TokenType::kGlobal); |
1553 | auto global = AddOrGet(&global_names, global_name); |
1554 | return Expr(global); |
1555 | } |
1556 | // Parse a local of the form `x`. |
1557 | // Right now we fail to parse `x.y`. |
1558 | case TokenType::kIdentifier: { |
1559 | auto ctor = ctors.Get(next.ToString()); |
1560 | if (ctor) { |
1561 | Consume(TokenType::kIdentifier); |
1562 | return Expr(ctor.value()); |
1563 | } else { |
1564 | auto spanned_idents = ParseHierarchicalName(); |
1565 | auto idents = spanned_idents.data; |
1566 | auto span = spanned_idents.span; |
1567 | return GetOp(GetHierarchicalName(idents), span); |
1568 | } |
1569 | } |
1570 | case TokenType::kGraph: { |
1571 | Consume(TokenType::kGraph); |
1572 | return LookupGraphBinding(next); |
1573 | } |
1574 | case TokenType::kMetaReference: { |
1575 | return Downcast<Expr>(ParseMetaRef()); |
1576 | } |
1577 | case TokenType::kFn: { |
1578 | Consume(TokenType::kFn); |
1579 | Expr e = ParseFunctionDef(); |
1580 | ICHECK(e->span.defined()) << "function spans must be defined.\n" << e; |
1581 | return e; |
1582 | } |
1583 | case TokenType::kIf: { |
1584 | Expr e = ParseIf(); |
1585 | return e; |
1586 | } |
1587 | case TokenType::kRef: { |
1588 | Consume(TokenType::kRef); |
1589 | Match(TokenType::kOpenParen); |
1590 | auto ref_value = ParseExpr(); |
1591 | Match(TokenType::kCloseParen); |
1592 | return static_cast<Expr>(RefCreate(ref_value)); |
1593 | } |
1594 | case TokenType::kRefRead: { |
1595 | return WithSpan<Expr>([&]() { |
1596 | Consume(TokenType::kRefRead); |
1597 | Match(TokenType::kOpenParen); |
1598 | auto ref = ParseExpr(); |
1599 | Match(TokenType::kCloseParen); |
1600 | return static_cast<Expr>(RefRead(ref)); |
1601 | }); |
1602 | } |
1603 | case TokenType::kRefWrite: { |
1604 | return WithSpan<Expr>([&]() { |
1605 | Consume(TokenType::kRefWrite); |
1606 | Match(TokenType::kOpenParen); |
1607 | auto ref = ParseExpr(); |
1608 | Match(TokenType::kComma); |
1609 | auto value = ParseExpr(); |
1610 | Match(TokenType::kCloseParen); |
1611 | return static_cast<Expr>(RefWrite(ref, value)); |
1612 | }); |
1613 | } |
1614 | case TokenType::kOpenParen: { |
1615 | Span sp = next->span; |
1616 | Consume(TokenType::kOpenParen); |
1617 | // parse '(' ')' |
1618 | if (WhenMatch(TokenType::kCloseParen)) { |
1619 | return Expr(Tuple(Array<Expr>())); |
1620 | } else { |
1621 | Expr subexpr = ParseExpr(); |
1622 | // parse '(' expr ')' |
1623 | if (WhenMatch(TokenType::kCloseParen)) { |
1624 | return subexpr; |
1625 | // parse '( expr ',' * ')' |
1626 | } else if (WhenMatch(TokenType::kComma)) { |
1627 | Array<Expr> exprs = {subexpr}; |
1628 | while (true) { |
1629 | if (WhenMatch(TokenType::kCloseParen)) { |
1630 | break; |
1631 | } else { |
1632 | auto element = ParseExpr(); |
1633 | auto comma = Peek(); |
1634 | if (WhenMatch(TokenType::kComma)) { |
1635 | sp = sp.Merge(element->span.Merge(comma->span)); |
1636 | } else { |
1637 | sp = sp.Merge(element->span); |
1638 | } |
1639 | exprs.push_back(element); |
1640 | } |
1641 | } |
1642 | Expr tuple = Tuple(exprs, sp); |
1643 | ICHECK(tuple->span.defined()) << "tuple span should be defined" ; |
1644 | return tuple; |
1645 | } |
1646 | } |
1647 | } |
1648 | default: { |
1649 | this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) |
1650 | << "expected an expression found " << Pretty(next->token_type)); |
1651 | return Expr(); |
1652 | } |
1653 | } |
1654 | }); |
1655 | |
1656 | if (WhenMatch(TokenType::kPeriod)) { |
1657 | auto token = Match(TokenType::kInteger); |
1658 | auto index = token.ToNumber(); |
1659 | auto span = token->span.Merge(expr->span); |
1660 | VLOG(9) << "Parser::ParseAtomicExpr: tuple get item" ; |
1661 | return relay::TupleGetItem(expr, index, span); |
1662 | } else { |
1663 | return expr; |
1664 | } |
1665 | } |
1666 | |
1667 | /*! \brief Parse a hierarchical name. |
1668 | * |
1669 | * The tokenizer produces a token stream of <id1> . <id2> |
1670 | * and so on for names of the form `nn.conv2d`. |
1671 | * Currently we only use string names everywhere instead |
1672 | * of a notion of a hierarchical name. |
1673 | * |
1674 | * The below utility reassembles a token stream into a |
1675 | * single stream inserting the required periods needed |
1676 | * to look up registered names. |
1677 | */ |
1678 | Spanned<Array<String>> ParseHierarchicalName() { |
1679 | Array<String> idents; |
1680 | Span span; |
1681 | while (Peek()->token_type == TokenType::kIdentifier) { |
1682 | auto token = Peek(); |
1683 | |
1684 | if (span.defined()) { |
1685 | span = span.Merge(token->span); |
1686 | } else { |
1687 | span = token->span; |
1688 | } |
1689 | |
1690 | auto name = token.ToString(); |
1691 | idents.push_back(name); |
1692 | Consume(TokenType::kIdentifier); |
1693 | |
1694 | // Keep parsing while we see a trailing period. |
1695 | if (Peek()->token_type == TokenType::kPeriod) { |
1696 | Consume(TokenType::kPeriod); |
1697 | continue; |
1698 | } else { |
1699 | // No more periods means we are done! |
1700 | break; |
1701 | } |
1702 | } |
1703 | |
1704 | return Spanned<Array<String>>(idents, span); |
1705 | } |
1706 | |
1707 | std::string GetHierarchicalName(Array<String> idents) { |
1708 | ICHECK_NE(idents.size(), 0); |
1709 | std::stringstream hierarchical_name; |
1710 | int i = 0; |
1711 | int periods = idents.size() - 1; |
1712 | for (auto ident : idents) { |
1713 | hierarchical_name << ident; |
1714 | if (i < periods) { |
1715 | hierarchical_name << "." ; |
1716 | i++; |
1717 | } |
1718 | } |
1719 | return hierarchical_name.str(); |
1720 | } |
1721 | |
1722 | /*! \brief Parse a shape. */ |
1723 | Array<tvm::PrimExpr> ParseShape() { |
1724 | auto dims = ParseSequence<tvm::PrimExpr>( |
1725 | TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { |
1726 | tvm::PrimExpr dim; |
1727 | if (Peek()->token_type == TokenType::kMetaReference) { |
1728 | dim = Downcast<tvm::PrimExpr>(ParseMetaRef()); |
1729 | } else if (WhenMatch(TokenType::kQuestion)) { |
1730 | dim = tvm::tir::Any(); |
1731 | } else { |
1732 | dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data); |
1733 | } |
1734 | |
1735 | return dim; |
1736 | }); |
1737 | return dims; |
1738 | } |
1739 | |
1740 | /*! \brief Parse a function type. */ |
1741 | Type ParseFunctionType() { |
1742 | auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
1743 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
1744 | |
1745 | Match(TokenType::kMinus); |
1746 | Match(TokenType::kRAngle); |
1747 | auto ret_type = ParseType(); |
1748 | |
1749 | return relay::FuncType(ty_params, ret_type, {}, {}); |
1750 | } |
1751 | |
1752 | // Parses a user defined ADT or type variable. |
1753 | Type ParseNonPrimitiveType(const Token& tok) { |
1754 | return WithSpan<Type>([&]() { |
1755 | auto name = tok.ToString(); |
1756 | Type head_type = LookupTypeVar(tok); |
1757 | |
1758 | if (!head_type.defined()) { |
1759 | // head_type = type_names.Get(name); |
1760 | head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle); |
1761 | } |
1762 | |
1763 | if (!head_type.defined()) { |
1764 | diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
1765 | << "the type constructor `" << name << "` is undefined" ); |
1766 | } |
1767 | |
1768 | Array<Type> arg_types; |
1769 | if (Peek()->token_type == TokenType::kLSquare) { |
1770 | arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
1771 | [&]() { return ParseType(); }); |
1772 | } |
1773 | |
1774 | if (arg_types.size()) { |
1775 | return static_cast<Type>(TypeCall(head_type, arg_types)); |
1776 | } else { |
1777 | if (head_type.as<GlobalTypeVarNode>()) { |
1778 | return static_cast<Type>(TypeCall(head_type, {})); |
1779 | } else { |
1780 | return static_cast<Type>(head_type); |
1781 | } |
1782 | } |
1783 | }); |
1784 | } |
1785 | |
1786 | /*! \brief Parses a TVM type. |
1787 | * |
1788 | * This matches either a `Tensor[shape, dtype]`, a user defined ADT, a tuple type, |
1789 | * a scalar type or an incomplete type `_`. |
1790 | */ |
1791 | Type ParseType() { |
1792 | return WithSpan<Type>([&]() -> Type { |
1793 | auto tok = Peek(); |
1794 | |
1795 | if (tok->token_type == TokenType::kOpenParen) { |
1796 | auto tys = |
1797 | ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma, |
1798 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
1799 | return relay::TupleType(tys); |
1800 | } else if (WhenMatch(TokenType::kFn)) { |
1801 | return ParseFunctionType(); |
1802 | } else if (WhenMatch(TokenType::kIdentifier)) { |
1803 | auto id = tok.ToString(); |
1804 | if (id == "Tensor" ) { |
1805 | Match(TokenType::kLSquare); |
1806 | auto shape = ParseShape(); |
1807 | Match(TokenType::kComma); |
1808 | auto dtype_tok = Match(TokenType::kIdentifier); |
1809 | auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); |
1810 | Match(TokenType::kRSquare); |
1811 | return TensorType(shape, dtype); |
1812 | } else { |
1813 | auto ty = tok.ToString(); |
1814 | if (ty.rfind("int" , 0) == 0 || ty.find("float" , 0) == 0 || ty.find("uint" , 0) == 0 || |
1815 | ty.find("bool" , 0) == 0) { |
1816 | // Need to do better error handling here. |
1817 | auto dtype = DataType(String2DLDataType(tok.ToString())); |
1818 | return TensorType({}, dtype); |
1819 | } else { |
1820 | return ParseNonPrimitiveType(tok); |
1821 | } |
1822 | } |
1823 | } else if (WhenMatch(TokenType::kUnderscore)) { |
1824 | return IncompleteType(); |
1825 | } else { |
1826 | this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
1827 | << "failed to parse type found " << tok); |
1828 | return Type(); |
1829 | } |
1830 | }); |
1831 | } |
1832 | |
1833 | template <typename R> |
1834 | R ConsumeWhitespace(std::function<R()> func) { |
1835 | auto old = this->ignore_whitespace; |
1836 | this->ignore_whitespace = true; |
1837 | while (tokens[pos]->token_type == TokenType::kWhitespace) { |
1838 | pos++; |
1839 | } |
1840 | auto res = func(); |
1841 | this->ignore_whitespace = old; |
1842 | return res; |
1843 | } |
1844 | |
1845 | Map<String, Array<ObjectRef>> ParseMetadata() { |
1846 | if (Peek()->token_type == TokenType::kMetadata) { |
1847 | return Match(TokenType::kMetadata).ToMetadata(); |
1848 | } else { |
1849 | return Map<String, Array<ObjectRef>>(); |
1850 | } |
1851 | } |
1852 | |
1853 | /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */ |
1854 | void DisplayNextN(int n) { |
1855 | std::cout << "remaining tokens: " << std::endl; |
1856 | auto bound = std::min(pos + n, static_cast<int>(tokens.size())); |
1857 | for (int i = 0; i < bound - pos; i++) { |
1858 | std::cout << tokens[pos + i] << std::endl; |
1859 | } |
1860 | } |
1861 | |
1862 | // A function for debugging the operator parser. |
1863 | void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) { |
1864 | std::cout << "Expr Stack: " ; |
1865 | for (auto expr : exprs) { |
1866 | std::cout << expr << ", " ; |
1867 | } |
1868 | |
1869 | std::cout << std::endl; |
1870 | std::cout << "Op Stack: " ; |
1871 | for (auto rule : rules) { |
1872 | std::cout << rule.op << ", " ; |
1873 | } |
1874 | |
1875 | std::cout << std::endl; |
1876 | } |
1877 | }; |
1878 | |
1879 | Parser InitParser(const std::string& file_name, const std::string& file_content, |
1880 | const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
1881 | VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); |
1882 | SourceName src_name = SourceName::Get(file_name); |
1883 | Source source(src_name, file_content); |
1884 | |
1885 | IRModule module; |
1886 | if (!init_module) { |
1887 | SourceMap source_map; |
1888 | module = IRModule({}, {}, {}, source_map); |
1889 | } else { |
1890 | module = init_module.value(); |
1891 | } |
1892 | |
1893 | module->source_map.Add(source); |
1894 | |
1895 | auto diag_ctx = DiagnosticContext::Default(module); |
1896 | auto tokens_and_table = Tokenize(diag_ctx, source); |
1897 | |
1898 | auto tokens = tokens_and_table.first; |
1899 | MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); |
1900 | |
1901 | // Merge any entries in init_meta_table into anything captured in the #[metadata] section |
1902 | // of the file_content. Metadata references within file_content must use indexes which account |
1903 | // for this ordering. |
1904 | for (const auto& pair : init_meta_table) { |
1905 | Array<ObjectRef> items; |
1906 | if (meta_data_table.count(pair.first)) { |
1907 | items = meta_data_table[pair.first]; |
1908 | } |
1909 | for (const auto& obj : pair.second) { |
1910 | items.push_back(obj); |
1911 | } |
1912 | meta_data_table.Set(pair.first, items); |
1913 | } |
1914 | |
1915 | return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); |
1916 | } |
1917 | |
1918 | IRModule ParseModule(const std::string& file_name, const std::string& file_content, |
1919 | const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
1920 | VLOG_CONTEXT << "ParseModule" ; |
1921 | VLOG(9) << "parsing and type-checking " << file_name; |
1922 | auto parser = InitParser(file_name, file_content, init_module, init_meta_table); |
1923 | auto mod = parser.ParseModule(); |
1924 | ICHECK(mod.defined()) << "The parser must return a non-null module." ; |
1925 | // NB(@jroesch): it is very important that we render any errors before we proceed |
1926 | // if there were any errors which allow the parser to proceed we must render them |
1927 | // here. |
1928 | parser.diag_ctx.Render(); |
1929 | auto infer_type = tvm::relay::transform::InferType(); |
1930 | ICHECK(infer_type.defined()) << "The type inferencer must be non-null." ; |
1931 | return infer_type(mod); |
1932 | } |
1933 | |
1934 | Expr ParseExpr(const std::string& file_name, const std::string& file_content) { |
1935 | VLOG(9) << "ParseExpr" ; |
1936 | auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable()); |
1937 | parser.ParseSemVer(false); |
1938 | parser.PushScope(); |
1939 | auto expr = parser.ParseExpr(); |
1940 | parser.Match(TokenType::kEndOfFile); |
1941 | // NB(@jroesch): it is very important that we render any errors before we proceed |
1942 | // if there were any errors which allow the parser to proceed we must render them |
1943 | // here. |
1944 | parser.diag_ctx.Render(); |
1945 | return expr; |
1946 | } |
1947 | |
1948 | /*! |
1949 | * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources |
1950 | * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for |
1951 | * modules constructed programaticaly rather than textually. |
1952 | */ |
1953 | Pass AnnotateSpans() { |
1954 | auto pass_func = [](const IRModule& mod, const PassContext& ctx) { |
1955 | String text = AsText(mod, /*show_meta_data=*/true); |
1956 | VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; |
1957 | return ParseModule("GeneratedSource" , text); |
1958 | }; |
1959 | return CreateModulePass(pass_func, 0, "AnnotateSpans" , {}); |
1960 | } |
1961 | |
1962 | TVM_REGISTER_GLOBAL("relay.parser.ParseModuleInContext" ) |
1963 | .set_body_typed([](const std::string& file_name, const std::string& file_content, |
1964 | const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
1965 | return ParseModule(file_name, file_content, init_module, init_meta_table); |
1966 | }); |
1967 | |
1968 | TVM_REGISTER_GLOBAL("relay.parser.ParseModule" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
1969 | ICHECK(args.size() >= 2 && args.size() <= 4) << "Expected 2-4 arguments, but got " << args.size(); |
1970 | if (args.size() == 2) { |
1971 | *ret = ParseModule(args[0], args[1]); |
1972 | } else if (args.size() == 3) { |
1973 | *ret = ParseModule(args[0], args[1], args[2]); |
1974 | } else { |
1975 | *ret = ParseModule(args[0], args[1], args[2], args[3]); |
1976 | } |
1977 | }); |
1978 | |
1979 | TVM_REGISTER_GLOBAL("relay.parser.ParseExpr" ) |
1980 | .set_body_typed([](tvm::String file_name, tvm::String file_content) { |
1981 | return ParseExpr(file_name, file_content); |
1982 | }); |
1983 | |
1984 | TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans" ).set_body_typed(AnnotateSpans); |
1985 | |
1986 | } // namespace relay |
1987 | } // namespace tvm |
1988 | |