1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file parser.cc |
22 | * \brief A parser for TVM IR. |
23 | */ |
#include <tvm/ir/module.h>
#include <tvm/node/reflection.h>
#include <tvm/parser/parser.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/virtual_device.h>

#include <fstream>
#include <utility>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./op_table.h"
#include "./span_check.h"
#include "./tokenizer.h"
#include "tvm/runtime/builtin_fp16.h"
44 | |
45 | namespace tvm { |
46 | namespace parser { |
47 | |
48 | using namespace relay; |
49 | using Expr = relay::Expr; |
50 | |
51 | /*! \brief The meta table maps from type key to a sequence of objects. */ |
52 | using MetaTable = Map<String, Array<ObjectRef>>; |
53 | |
54 | using tvm::transform::CreateModulePass; |
55 | using tvm::transform::PassContext; |
56 | |
57 | /*! \brief A helper for passing around spans with data structures with |
58 | * no span field. |
59 | */ |
60 | template <typename T> |
61 | struct Spanned { |
62 | T data; |
63 | Span span; |
64 | |
65 | Spanned() = default; |
66 | Spanned(const Spanned<T>& other) = default; |
67 | Spanned(T data, Span span) : data(data), span(span) {} |
68 | }; |
69 | |
70 | /*! \brief A wrapper structure for capturing the result of parsing |
71 | * a global definition *before* we add it to the IRModule. |
72 | * |
73 | * This enables the parser to parse everything in one pass before |
74 | * constructing the IRModule. |
75 | */ |
76 | struct GlobalFunc { |
77 | GlobalVar global; |
78 | Function function; |
79 | GlobalFunc() : global(), function() {} |
80 | GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {} |
81 | GlobalFunc(const GlobalFunc& gfunc) { |
82 | this->global = gfunc.global; |
83 | this->function = gfunc.function; |
84 | } |
85 | }; |
86 | |
/*! \brief A wrapper structure for capturing all top-level definitions
 * when parsing a module.
 *
 * Filled in by ParseDefinitions in source order and only added to the IRModule
 * afterwards by ParseModule.
 */
struct Definitions {
  /*! \brief The set of global functions, in the order they were parsed. */
  std::vector<GlobalFunc> funcs;
  /*! \brief The set of type definitions (including `extern` types), in the order they were parsed. */
  std::vector<TypeData> types;
  // TODO(@jroesch): contain meta-table below
};
97 | |
/*! \brief A structure representing the semantic versioning information
 * for a Relay program.
 *
 * Follows the Rule of Zero: copy and move operations are compiler-generated.
 * A hand-written copy constructor would suppress the implicit move operations.
 */
class SemVer {
 public:
  /*! \brief The major component; defaults to 0. */
  int major_version{0};
  /*! \brief The minor component; defaults to 0. */
  int minor_version{0};
  /*! \brief The patch component; defaults to 0. */
  int patch_version{0};

  /*! \brief Construct version 0.0.0. */
  SemVer() = default;
  /*! \brief Construct version `major_version.minor_version.patch_version`. */
  SemVer(int major_version, int minor_version, int patch_version)
      : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {}
};
115 | |
/*! \brief One lexical scope: associates raw source-level names with the binder
 * (variable, type variable, ...) each name denotes inside that scope.
 */
template <typename T>
struct Scope {
  /*! \brief The container type used for the name -> binder association. */
  using NameMap = std::unordered_map<std::string, T>;

  /*! \brief The internal map. */
  NameMap name_map;
};
124 | |
125 | /*! \brief A stack of scopes. |
126 | * |
127 | * In order to properly handle scoping we must maintain a stack of scopes. |
128 | * |
129 | * A stack allows users to write programs which contain repeated variable |
130 | * names and to properly handle both nested scopes and removal of variables |
131 | * when they go out of scope. |
132 | * |
133 | * This is the classic approach to lexical scoping. |
134 | */ |
135 | template <typename T> |
136 | class ScopeStack { |
137 | private: |
138 | std::vector<Scope<T>> scope_stack; |
139 | std::unordered_map<std::string, T> free_vars; |
140 | |
141 | public: |
142 | /*! \brief Adds a variable binding to the current scope. */ |
143 | void Add(const std::string& name, const T& value) { |
144 | if (!this->scope_stack.size()) { |
145 | LOG(FATAL) << "internal issue" ; |
146 | } |
147 | this->scope_stack.back().name_map.insert({name, value}); |
148 | } |
149 | |
150 | void AddFreeVar(const std::string& name, const T& value) { free_vars.insert({name, value}); } |
151 | |
152 | /*! \brief Looks up a variable name in the scope stack returning the matching variable |
153 | * in most recent scope. */ |
154 | T Lookup(const std::string& name) { |
155 | for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) { |
156 | auto it = scope->name_map.find(name); |
157 | if (it != scope->name_map.end()) { |
158 | return it->second; |
159 | } |
160 | } |
161 | |
162 | // Check if we bound a free variable declaration. |
163 | auto it = free_vars.find(name); |
164 | if (it != free_vars.end()) { |
165 | return it->second; |
166 | } |
167 | |
168 | return T(); |
169 | } |
170 | |
171 | /*! \brief Adds a fresh scope. */ |
172 | void PushStack() { this->scope_stack.push_back(Scope<T>()); } |
173 | |
174 | /*! \brief Removes the most recent scope. */ |
175 | void PopStack() { this->scope_stack.pop_back(); } |
176 | }; |
177 | |
178 | struct DuplicateKeyError : public Error { |
179 | explicit DuplicateKeyError(const std::string& msg) : Error(msg) {} |
180 | }; |
181 | |
182 | /*! \brief A table of interning strings as global function and type names. */ |
183 | template <typename T> |
184 | struct InternTable { |
185 | /*! \brief The internal table mapping strings to a unique allocation. */ |
186 | std::unordered_map<std::string, T> table; |
187 | DiagnosticContext* ctx; |
188 | |
189 | /*! \brief Add the unique allocation. */ |
190 | void Add(const std::string& name, const T& t) { |
191 | auto it = table.find(name); |
192 | if (it != table.end()) { |
193 | throw DuplicateKeyError("duplicate key name in intern table" ); |
194 | } else { |
195 | table.insert({name, t}); |
196 | } |
197 | } |
198 | |
199 | /*! \brief Return the unique allocation. */ |
200 | Optional<T> Get(const std::string& name) const { |
201 | auto it = table.find(name); |
202 | if (it != table.end()) { |
203 | return Optional<T>(it->second); |
204 | } else { |
205 | return Optional<T>(); |
206 | } |
207 | } |
208 | }; |
209 | |
210 | GlobalVar AddOrGet(InternTable<GlobalVar>* table, const std::string& name) { |
211 | auto var = table->Get(name); |
212 | if (var) { |
213 | return var.value(); |
214 | } else { |
215 | auto gvar = GlobalVar(name); |
216 | table->Add(name, gvar); |
217 | return gvar; |
218 | } |
219 | } |
220 | |
221 | GlobalTypeVar AddOrGet(InternTable<GlobalTypeVar>* table, const std::string& name, |
222 | TypeKind kind = TypeKind::kType) { |
223 | auto var = table->Get(name); |
224 | if (var) { |
225 | auto tvar = var.value(); |
226 | TypeKind& tvar_kind = const_cast<TypeKind&>(tvar->kind); |
227 | tvar_kind = kind; |
228 | return tvar; |
229 | } else { |
230 | auto gvar = GlobalTypeVar(name, kind); |
231 | table->Add(name, gvar); |
232 | return gvar; |
233 | } |
234 | } |
235 | |
236 | /*! \brief The parser class is the main interface to the parser. |
 * The parser is not currently exposed beyond this .cc file.
238 | * |
239 | * The parser is initialized with a diagnostic context, an |
240 | * operator table, and a token stream. |
241 | * |
242 | * The rest of the internal state is used to map the human readable |
243 | * form to in-memory IR representation. |
244 | * |
 * The main entry points to the parser are a set of parsing methods
246 | * such as `ParseModule` and `ParseExpr`. |
247 | * |
248 | * As with traditional recursive descent parsers the parsing methods |
249 | * are factored recursively just as one would do with a formal language |
250 | * grammar. |
251 | * |
252 | * You can view a recursive descent parser as a human friendly way to specify |
253 | * a state machine, and thus this factoring is necessary as the 'state' of this |
254 | * machine is the combination of the current parsing method and the next token. |
255 | * |
256 | * Parsing proceeds by matching a token and then dispatching to the appropriate |
257 | * method to parse the next tokens in the stream. |
258 | * |
259 | * For example if we are parsing a type and encounter a "Tensor" token we switch |
260 | * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`. |
261 | * |
262 | * Certain matches like this are unambiguous and proceed in a straight line fashion |
263 | * once the initial token is found. Other parsing is more complex and requires some |
264 | * tricks to correctly parse. |
265 | * |
266 | * For example when we find a '(' in an expression context, it may be part of |
267 | * a tuple, the arguments to a call, or a parenthesized expression. The below code |
 * disambiguates these cases by factoring expression parsing into a series of methods
269 | * which encode the parsing context and thus how to interpret the parenthesis. |
270 | * |
271 | * For more information one should be able to read the code in order starting with |
272 | * `ParseModule` or `ParseExpr`. |
273 | */ |
274 | class Parser { |
275 | public: |
/*! \brief The version that the parser is parsing. */
SemVer version;

/*! \brief The IRModule we are building. */
IRModule module;

/*! \brief The diagnostic context used for error reporting. */
DiagnosticContext diag_ctx;

/*! \brief The source being parsed. Stored by reference, so the referenced Source must
 * outlive this Parser. */
const Source& source;

/*! \brief The current position in the token stream (an index into `tokens`). */
int pos;

/*! \brief The token stream for the parser. */
std::vector<Token> tokens;

/*! \brief The configured operator table. */
OperatorTable op_table;

/*! \brief Configure the whitespace mode, right now we ignore all whitespace. */
bool ignore_whitespace;

/*! \brief A global mapping for GlobalVar. */
InternTable<GlobalVar> global_names;

/*! \brief A global mapping for type definitions. */
InternTable<GlobalTypeVar> type_names;

/*! \brief A global mapping for constructor names. */
InternTable<Constructor> ctors;

/*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */
std::unordered_map<int, Expr> graph_ctx;

/*! \brief The set of type scopes used for generics. */
ScopeStack<TypeVar> type_scopes;

/*! \brief The set of expression scopes used for lexical scope. */
ScopeStack<Var> expr_scopes;

/*! \brief The metadata section. */
MetaTable meta_table;
319 | |
320 | Parser(IRModule module, DiagnosticContext ctx, const Source& source, std::vector<Token> tokens, |
321 | OperatorTable op_table, MetaTable table) |
322 | : module(module), |
323 | diag_ctx(ctx), |
324 | source(source), |
325 | pos(0), |
326 | tokens(tokens), |
327 | op_table(op_table), |
328 | ignore_whitespace(true), |
329 | meta_table(table) { |
330 | InitializeGlobals(); |
331 | InitializeTypeDefs(); |
332 | } |
333 | |
334 | /*! If we are parsing into a module with previously loaded data types we need to |
335 | * map constructor names and variable names in the global tables. |
336 | */ |
337 | void InitializeTypeDefs() { |
338 | for (auto pair : this->module->type_definitions) { |
339 | type_names.Add(pair.first->name_hint, pair.first); |
340 | for (auto ctor : pair.second->constructors) { |
341 | ctors.Add(ctor->name_hint, ctor); |
342 | } |
343 | } |
344 | } |
345 | |
346 | void InitializeGlobals() { |
347 | for (auto pair : this->module->functions) { |
348 | global_names.Add(pair.first->name_hint, pair.first); |
349 | } |
350 | } |
351 | |
352 | /*! \brief Examine the next token in the stream, the current parser is configured to be |
353 | * whitespace insensitive so we will skip all whitespace or comment tokens. */ |
354 | Token Peek() { |
355 | // For now we ignore all whitespace tokens and comments. |
356 | // We can tweak this behavior later to enable white space sensitivity in the parser. |
357 | while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace && |
358 | (tokens.at(pos)->token_type == TokenType::kWhitespace || |
359 | tokens.at(pos)->token_type == TokenType::kNewline || |
360 | tokens.at(pos)->token_type == TokenType::kLineComment || |
361 | tokens.at(pos)->token_type == TokenType::kComment)) { |
362 | pos++; |
363 | } |
364 | |
365 | if (pos < static_cast<int64_t>(tokens.size())) { |
366 | return Token(this->tokens.at(pos)); |
367 | } else { |
368 | return Token::Null(); |
369 | } |
370 | } |
371 | |
372 | /*! \brief Lookahead by N tokens. |
373 | * \param n The number of tokens to lookahead. |
374 | * \return The Nth token. |
375 | */ |
376 | Token Lookahead(int n) { |
377 | ICHECK_GE(n, 1) << "lookahead is only valid when n >= 1" ; |
378 | |
379 | // We intend to skip n - 1 tokens, then return the nth. |
380 | auto old_pos = pos; |
381 | for (int i = 0; i < n - 1; i++) { |
382 | Peek(); |
383 | pos++; |
384 | } |
385 | |
386 | auto tok = Peek(); |
387 | pos = old_pos; |
388 | return tok; |
389 | } |
390 | |
391 | /*! \brief Consume a token, this method is the lowest level way to consume a token |
392 | * and will not ignore white space or look ahead in anyway. |
393 | * |
394 | * /param token_type The token type to match. |
395 | */ |
396 | void Consume(const TokenType& token_type) { |
397 | if (tokens[pos]->token_type != token_type) { |
398 | this->diag_ctx.EmitFatal(Diagnostic::Error(tokens[pos]->span) |
399 | << "expected a " << Pretty(token_type) << " found " |
400 | << Pretty(Peek()->token_type)); |
401 | } |
402 | pos++; |
403 | } |
404 | |
405 | /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such |
406 | * as whitespace or comments returning the first meaningful token. |
407 | * |
408 | * We then try and consume the requested token, this will trigger an error if the |
409 | * current token does not match the token_type. |
410 | */ |
411 | Token Match(const TokenType& token_type) { |
412 | auto tok = Peek(); |
413 | Consume(token_type); |
414 | return tok; |
415 | } |
416 | |
417 | /*! Conditionally consume a token when it matches, this will never trigger an error |
418 | * as we guard against consuming the token before we do. |
419 | * |
420 | * Useful for matching optional tokens, effectively looksahead by one. |
421 | */ |
422 | bool WhenMatch(const TokenType& token_type) { |
423 | VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek(); |
424 | if (Peek()->token_type == token_type) { |
425 | Consume(token_type); |
426 | return true; |
427 | } else { |
428 | return false; |
429 | } |
430 | } |
431 | |
432 | /* \brief Add a graph binding to the parsing context |
433 | * |
434 | * For example if we parse %0 = add(...), map 0 -> add(...), etc. |
435 | */ |
436 | void AddGraphBinding(const Token& token, const Expr& expr) { |
437 | auto graph_no = token.ToNumber(); |
438 | this->graph_ctx.insert({graph_no, expr}); |
439 | } |
440 | |
441 | /* \brief Lookup a previously bound graph variable. |
442 | * |
443 | * Note: we take tokens in all lookup methods so that we |
444 | * that we can do error reporting based on token location. |
445 | */ |
446 | Expr LookupGraphBinding(const Token& token) { |
447 | auto graph_no = token.ToNumber(); |
448 | auto it = this->graph_ctx.find(graph_no); |
449 | if (it != this->graph_ctx.end()) { |
450 | return it->second; |
451 | } else { |
452 | LOG(FATAL) << "Local variable %" << graph_no << " has not yet been defined" ; |
453 | throw; |
454 | } |
455 | } |
456 | |
457 | /*! \brief Bind a local variable in the expression scope. |
458 | * |
459 | * "x" -> Var("x"), these are needed to map from the raw string names |
460 | * to unique variable nodes. |
461 | * If a virtual device is specified, sets the virtual device of the variable. |
462 | */ |
463 | Var BindVar(const std::string& name, const relay::Type& type_annotation, |
464 | Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) { |
465 | auto var = Var(name, type_annotation); |
466 | var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained()); |
467 | VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var); |
468 | this->expr_scopes.Add(name, var); |
469 | return var; |
470 | } |
471 | |
472 | /*! \brief Bind a local variable in the expression scope. |
473 | * |
474 | * "x" -> Var("x"), these are needed to map from the raw string names |
475 | * to unique variable nodes. |
476 | */ |
477 | Var BindFreeVar(const std::string& name, const relay::Type& type_annotation) { |
478 | auto var = Var(name, type_annotation); |
479 | this->expr_scopes.AddFreeVar(name, var); |
480 | return var; |
481 | } |
482 | |
483 | /*! \brief Bind a type variable in the type scope. |
484 | * |
485 | * "A" -> TypeVar("A", ...), these are needed to map from raw string names |
486 | * to unique type variable nodes. |
487 | */ |
488 | TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) { |
489 | auto type_var = TypeVar(name, type_kind); |
490 | this->type_scopes.Add(name, type_var); |
491 | return type_var; |
492 | } |
493 | |
494 | /*! \brief Lookup a variable in the expression scope. |
495 | * |
496 | * Note: all lookup methods take tokens intentionally for error reporting information. |
497 | */ |
498 | Var LookupLocal(const Token& local) { |
499 | auto var = this->expr_scopes.Lookup(local.ToString()); |
500 | if (!var.defined()) { |
501 | diag_ctx.Emit(Diagnostic::Error(local->span) |
502 | << "this local variable has not been previously declared" ); |
503 | } |
504 | return var; |
505 | } |
506 | |
507 | /*! \brief Lookup a variable in the type scope. |
508 | * |
509 | * Note: all lookup methods take tokens intentionally for error reporting information. |
510 | */ |
511 | TypeVar LookupTypeVar(const Token& ident) { |
512 | auto var = this->type_scopes.Lookup(ident.ToString()); |
513 | return var; |
514 | } |
515 | |
516 | /*! \brief Add an expression scope to the scope stack. */ |
517 | void PushScope() { this->expr_scopes.PushStack(); } |
518 | |
519 | /*! \brief Remove N expression scopes from the scope stack. */ |
520 | void PopScopes(int n) { |
521 | for (int i = 0; i < n; i++) { |
522 | this->expr_scopes.PopStack(); |
523 | } |
524 | } |
525 | |
526 | /*! \brief Add an type scope to the scope stack. */ |
527 | void PushTypeScope() { this->type_scopes.PushStack(); } |
528 | |
529 | /*! \brief Remove N type scopes from the scope stack. */ |
530 | void PopTypeScopes(int n) { |
531 | for (int i = 0; i < n; i++) { |
532 | this->type_scopes.PopStack(); |
533 | } |
534 | } |
535 | |
536 | /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */ |
537 | NDArray NumberToNDArray(const Token& token) { |
538 | if (token->token_type == TokenType::kInteger) { |
539 | return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data)); |
540 | } else if (token->token_type == TokenType::kFloat) { |
541 | return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data)); |
542 | } else { |
543 | LOG(FATAL) << "internal error: should only call this function on numeric tokens" ; |
544 | } |
545 | } |
546 | |
547 | [[noreturn]] void ParseError(const Token& token, const std::string& msg) { |
548 | throw std::runtime_error(msg); |
549 | } |
550 | |
551 | /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */ |
552 | template <typename R> |
553 | R Bracket(TokenType open, TokenType close, std::function<R()> parser) { |
554 | Match(open); |
555 | R result = parser(); |
556 | Match(close); |
557 | return result; |
558 | } |
559 | |
560 | /*! \brief Parse `(` parser() `)`. */ |
561 | template <typename R> |
562 | R Parens(std::function<R()> parser) { |
563 | return Bracket(TokenType::kOpenParen, TokenType::kCloseParen, parser); |
564 | } |
565 | |
566 | /*! \brief Parse `{` parser() `}`. */ |
567 | template <typename R> |
568 | R Block(std::function<R()> parser) { |
569 | return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser); |
570 | } |
571 | |
572 | template <typename R> |
573 | R WithSpan(std::function<R()> parser) { |
574 | auto start_span = Peek()->span; |
575 | VLOG(9) << "WithSpan: start_span = " << start_span; |
576 | R ast = parser(); |
577 | if (ast.defined()) { |
578 | // The token at the head of the stream is now 1 past where we parsed. So we find its start |
579 | // position as its start and end, so that when we merge we only grow the spanned region |
580 | // to the start of the current stream. |
581 | auto span_pos = pos - 1; |
582 | while ((tokens.at(span_pos)->token_type == TokenType::kWhitespace || |
583 | tokens.at(span_pos)->token_type == TokenType::kNewline || |
584 | tokens.at(span_pos)->token_type == TokenType::kLineComment || |
585 | tokens.at(span_pos)->token_type == TokenType::kComment)) { |
586 | span_pos--; |
587 | } |
588 | auto end_token = tokens.at(span_pos); |
589 | VLOG(9) << "WithSpan: end_span = " << end_token->span; |
590 | ast->span = start_span.Merge(end_token->span); |
591 | } |
592 | return ast; |
593 | } |
594 | |
595 | struct MetaRef { |
596 | std::string type_key; |
597 | uint64_t node_index; |
598 | Span span; |
599 | MetaRef(std::string type_key, uint64_t node_index, Span span) |
600 | : type_key(type_key), node_index(node_index), span(span) {} |
601 | }; |
602 | |
603 | MetaRef MetaRefFromToken(const Token& tok) { |
604 | Call ref = Downcast<Call>(tok->data); |
605 | auto attrs = ref->attrs.as<MetaRefAttrs>(); |
606 | auto type_key = attrs->node_type_key; |
607 | auto index = attrs->node_index; |
608 | return MetaRef(type_key, index, ref->span); |
609 | } |
610 | |
611 | /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. |
612 | * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` |
613 | * the second, and so on. |
614 | */ |
615 | ObjectRef ParseMetaRef() { |
616 | auto meta_ref_tok = Match(TokenType::kMetaReference); |
617 | auto meta_ref = MetaRefFromToken(meta_ref_tok); |
618 | auto it = this->meta_table.find(meta_ref.type_key); |
619 | if (it != this->meta_table.end()) { |
620 | auto nodes = (*it).second; |
621 | if (meta_ref.node_index < nodes.size()) { |
622 | return nodes[meta_ref.node_index]; |
623 | } else { |
624 | this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
625 | << "the node index `" << meta_ref.node_index |
626 | << "` is out of bounds for `" << meta_ref.type_key << "`" ); |
627 | return ObjectRef(); |
628 | } |
629 | } else { |
630 | this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) |
631 | << "no entry in the meta table for `" << meta_ref.type_key << "`" ); |
632 | return ObjectRef(); |
633 | } |
634 | } |
/*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
 * ending with a stop token.
 *
 * The simple form being <start> (<parse()> <separator>)* <stop>.
 *
 * This also provides an optional fifth argument, `before_stop`, which is allowed to run when
 * the inner sequence can not proceed.
 *
 * This is useful for parsing things like attributes which don't match the standard expression
 * parsers but are contained within the stop token.
 *
 * \param start The token that must open the sequence.
 * \param sep The separator between elements. A trailing separator before `stop` is accepted.
 *        After the first separator has been seen, a missing separator between later elements
 *        is not reported, because the result of `WhenMatch(sep)` in the loop is ignored.
 * \param stop The token that must close the sequence.
 * \param parse Parses a single element.
 * \param before_stop Optional hook tried right after `start` and before each later element.
 *        When it returns true, `stop` is matched and the elements collected so far are
 *        returned.
 * \return The parsed elements (empty for `<start> <stop>`). A fatal diagnostic is emitted if
 *         neither `stop` nor `sep` follows the first element.
 */
template <typename T>
Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
                       std::function<bool()> before_stop = nullptr) {
  VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
          << " stop=" << ToString(stop);
  Match(start);

  // This is for the empty arguments list case, if we have <start> <leftovers> <stop> token stream
  // we must parse leftovers, then match a stop token.
  if (before_stop) {
    auto did_parse = before_stop();
    if (did_parse) {
      Match(stop);
      return {};
    }
  }

  // This is the case in which we find an empty arguments lists and no leftovers.
  if (WhenMatch(stop)) {
    return Array<T>();
  } else {
    VLOG(9) << "Parser::ParseSequence: parse first";
    auto data = parse();
    Array<T> elements = {data};

    if (WhenMatch(stop)) {
      return elements;
      // parse '( expr ',' * ')'
    } else if (WhenMatch(sep)) {
      while (true) {
        VLOG(9) << "Parser::ParseSequence: parse element";
        if (WhenMatch(stop)) {
          break;
        } else {
          // If a before_stop hook was given, let it consume any leftovers; when it does,
          // the sequence ends here with the elements parsed so far.
          if (before_stop) {
            auto did_parse = before_stop();
            if (did_parse) {
              Match(stop);
              return elements;
            }
          }
          auto data = parse();
          // The separator is optional here: its absence is not reported.
          WhenMatch(sep);
          elements.push_back(data);
        }
      }
      return elements;
    } else {
      auto next = Peek();
      this->diag_ctx.EmitFatal(Diagnostic::Error(next->span)
                               << "expected a " << Pretty(stop) << " found "
                               << Pretty(next->token_type));
      return Array<T>(nullptr);
    }
  }
}
703 | |
704 | /*! \brief Parse a full IRModule. */ |
705 | IRModule ParseModule() { |
706 | // Parse the semver header at the top of the module. |
707 | this->version = ParseSemVer(); |
708 | // Parse the definitions. |
709 | auto defs = ParseDefinitions(); |
710 | // Parse the metadata section at the end. |
711 | auto metadata = ParseMetadata(); |
712 | |
713 | Match(TokenType::kEndOfFile); |
714 | |
715 | for (auto type_def : defs.types) { |
716 | module->AddTypeDef(type_def->header, type_def); |
717 | } |
718 | |
719 | for (auto func : defs.funcs) { |
720 | module->Add(func.global, func.function, true); |
721 | } |
722 | |
723 | return module; |
724 | } |
725 | |
726 | /*! \brief Parse the semantic versioning header. */ |
727 | SemVer ParseSemVer(bool required = true) { |
728 | if (Peek()->token_type == TokenType::kVersion) { |
729 | auto version = Match(TokenType::kVersion); |
730 | // TODO(@jroesch): we currently only support 0.0.5. |
731 | if (version.ToString() != "\"0.0.5\"" ) { |
732 | this->diag_ctx.Emit(Diagnostic::Error(version->span) |
733 | << "invalid semantic version `" << version.ToString() << "`" ); |
734 | } |
735 | } else if (required) { |
736 | this->diag_ctx.Emit(Diagnostic::Error(Peek()->span) |
737 | << "expected text format semantic version, found a " |
738 | << PrettyPrint(Peek())); |
739 | |
740 | this->diag_ctx.Emit(Diagnostic::Help(Peek()->span) |
741 | << "you can annotate it as #[version = \"0.0.5\"]" ); |
742 | } |
743 | return SemVer(0, 0, 5); |
744 | } |
745 | |
746 | /*! \brief Parse zero or more Relay definitions. */ |
747 | Definitions ParseDefinitions() { |
748 | Definitions defs; |
749 | |
750 | while (true) { |
751 | auto next = Peek(); |
752 | switch (next->token_type) { |
753 | case TokenType::kDefn: { |
754 | Consume(TokenType::kDefn); |
755 | auto global_tok = Match(TokenType::kGlobal); |
756 | auto global_name = global_tok.ToString(); |
757 | auto global = AddOrGet(&global_names, global_name); |
758 | auto func = WithSpan<relay::Function>([&]() { return ParseFunctionDef(); }); |
759 | ICHECK(func->span.defined()) << "spans must be set in parser" ; |
760 | defs.funcs.push_back(GlobalFunc(global, func)); |
761 | continue; |
762 | } |
763 | case TokenType::kTypeDef: { |
764 | defs.types.push_back(ParseTypeDef()); |
765 | continue; |
766 | } |
767 | case TokenType::kExtern: { |
768 | Consume(TokenType::kExtern); |
769 | auto type_def = ParseTypeDef(); |
770 | if (type_def->constructors.size()) { |
771 | diag_ctx.Emit(Diagnostic::Error(next->span) |
772 | << "an external type may not have any constructors" ); |
773 | } |
774 | defs.types.push_back(type_def); |
775 | } |
776 | default: |
777 | return defs; |
778 | } |
779 | } |
780 | } |
781 | |
782 | /*! \brief Parse zero or more Relay type definitions. */ |
783 | TypeData ParseTypeDef() { |
784 | // Match the `type` keyword. |
785 | Match(TokenType::kTypeDef); |
786 | // Parse the type's identifier. |
787 | auto type_tok = Match(TokenType::kIdentifier); |
788 | auto type_id = type_tok.ToString(); |
789 | auto type_global = AddOrGet(&type_names, type_id, TypeKind::kAdtHandle); |
790 | |
791 | Array<TypeVar> generics; |
792 | |
793 | bool should_pop = false; |
794 | if (Peek()->token_type == TokenType::kLSquare) { |
795 | // If we have generics we need to add a type scope. |
796 | PushTypeScope(); |
797 | should_pop = true; |
798 | generics = ParseSequence<TypeVar>( |
799 | TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
800 | auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
801 | return BindTypeVar(type_var_name, TypeKind::kType); |
802 | }); |
803 | } |
804 | |
805 | Array<tvm::Constructor> ctors; |
806 | if (Peek()->token_type == TokenType::kLCurly) { |
807 | // Parse the list of constructors. |
808 | ctors = ParseSequence<tvm::Constructor>( |
809 | TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, [&]() { |
810 | // First match the name of the constructor. |
811 | auto ctor_tok = Match(TokenType::kIdentifier); |
812 | auto ctor_name = ctor_tok.ToString(); |
813 | |
814 | Constructor ctor; |
815 | // Match the optional field list. |
816 | if (Peek()->token_type != TokenType::kOpenParen) { |
817 | ctor = tvm::Constructor(ctor_name, {}, type_global); |
818 | } else { |
819 | auto arg_types = |
820 | ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
821 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
822 | ctor = tvm::Constructor(ctor_name, arg_types, type_global); |
823 | } |
824 | |
825 | ICHECK(ctor.defined()); |
826 | |
827 | try { |
828 | this->ctors.Add(ctor_name, ctor); |
829 | } catch (const DuplicateKeyError& e) { |
830 | this->diag_ctx.EmitFatal(Diagnostic::Error(ctor_tok->span) |
831 | << "a constructor with the name " |
832 | << "`" << ctor_name << "` " |
833 | << "was previously defined" ); |
834 | } |
835 | |
836 | return ctor; |
837 | }); |
838 | } |
839 | |
840 | // Now pop the type scope. |
841 | if (should_pop) { |
842 | PopTypeScopes(1); |
843 | } |
844 | |
845 | return TypeData(type_global, generics, ctors); |
846 | } |
847 | |
848 | std::string HackTokensAsString(int n) { |
849 | std::stringstream key; |
850 | n = std::min(static_cast<int>(tokens.size() - pos), n); |
851 | for (int i = 0; i < n; i++) { |
852 | key << ToString(tokens.at(pos + i)->token_type); |
853 | } |
854 | return key.str(); |
855 | } |
856 | |
857 | std::vector<Rule> ParseOp() { |
858 | std::vector<Rule> matched; |
859 | Peek(); |
860 | for (int i = 4; i > 0; i--) { |
861 | auto key = HackTokensAsString(i); |
862 | auto it = this->op_table.this_is_a_hack.find(key); |
863 | if (it != this->op_table.this_is_a_hack.end()) { |
864 | pos = pos + i; |
865 | matched.push_back(it->second); |
866 | } |
867 | } |
868 | |
869 | return matched; |
870 | } |
871 | |
  /*!
   * \brief Parse a single Relay expression.
   *
   * Accepts a `;`-separated sequence `e1; e2; ...; en`. After each element a `;` is
   * tried, and parsing stops at the first element not followed by one. A sequence of
   * more than one element is desugared right-to-left into nested lets, roughly
   * `let <unnamed var> = e1; ...; en`, each bound variable having IncompleteType.
   *
   * `let` bindings and graph bindings (a kGraph token with `=` two tokens ahead) are
   * handed to ParseBindingExpr. `{ ... }` opens a nested scope. Everything else is
   * parsed by ParseExprBinOp.
   */
  Expr ParseExpr() {
    VLOG(9) << "Parser::ParseExpr" ;
    return WithSpan<Expr>([this] {
      // The elements of the `;`-separated sequence, in source order.
      std::vector<Expr> exprs;

      while (true) {
        VLOG(9) << "Parser::ParseExpr: parsing a single expression" ;
        auto next = Peek();
        switch (next->token_type) {
          // For graph or let, match first rhs, then invoke ParseBindingExpr
          // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
          case TokenType::kLCurly: {
            // NB: Might need to optimize to remove deep recursion.
            // Stack should only grow proportionally to the number of
            // nested scopes.
            // Parses `{` expression `}`.
            auto block = WithSpan<Expr>([&]() {
              return Bracket<Expr>(TokenType::kLCurly, TokenType::kRCurly, [&]() {
                PushScope();
                auto expr = ParseExpr();
                PopScopes(1);
                return expr;
              });
            });
            exprs.push_back(block);
            break;
          }
          case TokenType::kFreeVar: {
            // Declares a free variable `%x` with an optional `: Type` annotation
            // (IncompleteType when omitted). Nothing is pushed onto `exprs`.
            // NOTE(review): if this declaration is the last element (no `;` after it),
            // `exprs` may be empty and the ICHECK_GE below fails as an internal error
            // rather than a user-facing diagnostic.
            Consume(TokenType::kFreeVar);
            auto var_token = Match(TokenType::kLocal);

            Type type;
            if (WhenMatch(TokenType::kColon)) {
              type = ParseType();
            } else {
              type = IncompleteType();
            }

            BindFreeVar(var_token.ToString(), type);
            break;
          }
          // Parses `let ...`;
          case TokenType::kLet:
            exprs.push_back(ParseBindingExpr());
            break;
          case TokenType::kMatch:
          case TokenType::kPartialMatch: {
            // kMatch builds a total match, kPartialMatch a partial one.
            bool is_total = next->token_type == TokenType::kMatch;
            Consume(next->token_type);
            exprs.push_back(ParseMatch(is_total));
            break;
          }

          // %x ...
          case TokenType::kGraph:
            // Only `%n = ...` is a binding; a bare graph reference is an ordinary
            // expression and is handled by the default case.
            if (Lookahead(2)->token_type == TokenType::kEqual) {
              exprs.push_back(ParseBindingExpr());
              break;
            }
            // intentional fall through here.
          default: {
            exprs.push_back(ParseExprBinOp());
            break;
          }
        }

        // Continue the sequence only while elements are separated by ';'.
        if (!WhenMatch(TokenType::kSemicolon)) {
          break;
        }
      }

      ICHECK_GE(exprs.size(), 1);

      if (exprs.size() == 1) {
        // ICHECK(exprs[0].defined() && exprs[0]->span.defined())
        //     << "parser must set expression spans.\n"
        //     << exprs[0];
        return exprs[0];
      } else {
        // Fold from the back: the last element is the innermost body, and each earlier
        // element becomes the value of a let bound to a fresh unnamed variable.
        auto body = exprs.back();
        exprs.pop_back();
        while (exprs.size()) {
          auto value = exprs.back();
          ICHECK(value->span.defined()) << "parser must set expression spans." ;
          exprs.pop_back();
          body = relay::Let(Var("" , IncompleteType()), value, body, value->span.Merge(body->span));
        }
        ICHECK(body->span.defined()) << "parser must set expression spans." ;
        return body;
      }
    });
  }
965 | |
966 | /*! \brief Parse a "binding expression"; an expression where |
967 | * a graph or let variable is bound. |
968 | * |
969 | * In order to avoid stack overflow this is implemented in a special |
970 | * iterative way to keep stack depth constant in a long chain of bindings. |
971 | */ |
  Expr ParseBindingExpr() {
    // We use a loop here so that the stack depth
    // does not grow linearly with a sequence of
    // graph or let bindings.
    //
    // Assuming we start at call depth k, we will
    // enter k + c call frames to parse the RHS
    // of the bindings where `c` is the depth
    // of recursion needed by RHS.
    //
    // If RHS is a call expresssion the c=1.
    //
    // Once we have parsed the RHS we will be
    // back at depth K, and will return to
    // this loop header to parse another
    // graph or let binding.
    //
    // This ensures for n sequential bindings
    // the call depth will be the same before
    // and after parsing the n bindings.
    VLOG(9) << "Parser::ParseBindingExpr" ;
    // Let bindings seen so far as (variable, value, span of the `let` token).
    // Graph bindings are not recorded here; they are registered via AddGraphBinding.
    std::vector<std::tuple<Var, Expr, Span>> bindings;
    // Number of scopes pushed (one per `let`) that must be popped before returning.
    int scopes = 0;

    while (true) {
      auto next = Peek();
      if (next->token_type == TokenType::kGraph && Lookahead(2)->token_type == TokenType::kEqual) {
        // Graph binding `%n = <expr>;`: no let node and no new scope is created.
        Match(TokenType::kGraph);
        Match(TokenType::kEqual);
        auto val = this->ParseExprBinOp();
        Match(TokenType::kSemicolon);
        AddGraphBinding(next, val);
      } else if (next->token_type == TokenType::kLet) {
        auto span = next->span;
        // Parse the 'let'.
        Consume(TokenType::kLet);

        // Parse the local '%<id>'.
        auto local_tok = Match(TokenType::kLocal);
        auto string = local_tok.ToString();

        // Parse the optional type annotation (':' <type>).
        // Without an annotation `type` stays undefined and is passed to BindVar as-is.
        Type type;
        if (WhenMatch(TokenType::kColon)) {
          type = ParseType();
        }

        // NOTE(review): the variable is bound before its value is parsed and before
        // PushScope() below. It is therefore presumably visible inside its own value
        // and lands in the enclosing scope rather than the pushed one. Confirm this is
        // intended against BindVar/PushScope.
        auto var = BindVar(string, type);

        // Parse the '=';
        Match(TokenType::kEqual);

        // Parse the body, and the ';'.
        auto val = this->ParseExprBinOp();
        Consume(TokenType::kSemicolon);

        // Add the bindings to the local data structure.
        std::tuple<relay::Var, relay::Expr, Span> tuple(var, val, span);
        bindings.push_back(tuple);
        scopes++;
        PushScope();
      } else {
        // This is the only case we will increase the stack
        // depth.
        //
        // If we parse a program which is a sequence of N bindings
        // followed by a single body expression we will end up with
        // a call depth of 3, the first call to ParseExpr, then
        // ParseBindingExpr, then finally ParseExpr once more.

        auto body = this->ParseExpr();

        // Remove the same number of scopes we added.
        PopScopes(scopes);

        if (bindings.size() == 0) {
          return body;
        } else {
          // We can now build the let binding up backwards.
          // The last binding wraps the body first, so the first `let` ends up outermost.
          for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) {
            auto span = body->span.Merge(std::get<2>(*binding));
            body = relay::Let(std::get<0>(*binding), std::get<1>(*binding), body, span);
          }
          return body;
        }
      }
    }
  }
1060 | |
1061 | /*! Parse a function definition without a leading keyword or identifier. |
1062 | * |
1063 | * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. |
1064 | */ |
1065 | Function ParseFunctionDef() { |
1066 | VLOG(9) << "Parser::ParseFunctionDef" ; |
1067 | return WithSpan<Function>([&]() { |
1068 | PushScope(); |
1069 | PushTypeScope(); |
1070 | |
1071 | Array<TypeVar> generics; |
1072 | if (Peek()->token_type == TokenType::kLSquare) { |
1073 | generics = ParseSequence<TypeVar>( |
1074 | TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { |
1075 | auto type_var_name = Match(TokenType::kIdentifier).ToString(); |
1076 | return BindTypeVar(type_var_name, TypeKind::kType); |
1077 | }); |
1078 | } |
1079 | |
1080 | Map<String, ObjectRef> raw_attrs; |
1081 | |
1082 | auto params = ParseSequence<Var>( |
1083 | TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
1084 | [&]() { |
1085 | auto token = Match(TokenType::kLocal); |
1086 | auto string = token.ToString(); |
1087 | |
1088 | // The fake attributes where the virtual device is specified. |
1089 | VirtualDevice virtual_device; |
1090 | if (WhenMatch(TokenType::kLCurly)) { |
1091 | Map<String, ObjectRef> fake_attrs = ParseAttrs(); |
1092 | VLOG(9) << "Fake attributes for function parameter: " << fake_attrs; |
1093 | Match(TokenType::kRCurly); |
1094 | if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) { |
1095 | ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>()) |
1096 | << "Expected the " << kVirtualDevice |
1097 | << " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey(); |
1098 | virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]); |
1099 | } |
1100 | } |
1101 | |
1102 | Type type; |
1103 | if (WhenMatch(TokenType::kColon)) { |
1104 | type = ParseType(); |
1105 | } |
1106 | return BindVar(string, type, virtual_device); |
1107 | }, |
1108 | [&] { |
1109 | auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; |
1110 | auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; |
1111 | |
1112 | if (is_ident && next_is_equal) { |
1113 | raw_attrs = ParseAttrs(); |
1114 | return true; |
1115 | } |
1116 | |
1117 | return false; |
1118 | }); |
1119 | |
1120 | Type ret_type; |
1121 | if (WhenMatch(TokenType::kMinus)) { |
1122 | Match(TokenType::kRAngle); |
1123 | ret_type = ParseType(); |
1124 | } |
1125 | |
1126 | auto body = Block<Expr>([&]() { return ParseExpr(); }); |
1127 | |
1128 | PopTypeScopes(1); |
1129 | PopScopes(1); |
1130 | |
1131 | // TODO(@jroesch): attributes should never be null, they should always be empty. |
1132 | if (raw_attrs.size()) { |
1133 | // Promote kVirtualDevice to first-class |
1134 | if (raw_attrs.count(kVirtualDevice)) { |
1135 | ObjectRef vid = raw_attrs.at(kVirtualDevice); |
1136 | ICHECK(vid.as<VirtualDeviceNode>()) |
1137 | << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " |
1138 | << vid->GetTypeKey(); |
1139 | |
1140 | DictAttrs attrs; |
1141 | // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the |
1142 | // attributes |
1143 | if (raw_attrs.size() > 1) { |
1144 | raw_attrs.erase(kVirtualDevice); |
1145 | attrs = DictAttrs(raw_attrs); |
1146 | } |
1147 | Function func = relay::Function(params, body, ret_type, generics, attrs); |
1148 | func->virtual_device_ = vid; |
1149 | return func; |
1150 | } else { |
1151 | return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); |
1152 | } |
1153 | } else { |
1154 | return relay::Function(params, body, ret_type, generics, tvm::DictAttrs()); |
1155 | } |
1156 | }); |
1157 | } |
1158 | |
1159 | /*! \brief Parse an if-expression. */ |
1160 | Expr ParseIf() { |
1161 | return WithSpan<Expr>([&]() { |
1162 | VLOG(9) << "Parser::ParseIf" ; |
1163 | Consume(TokenType::kIf); |
1164 | |
1165 | auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); }); |
1166 | |
1167 | auto true_branch = Block<Expr>([&] { |
1168 | this->PushScope(); |
1169 | auto expr = ParseExpr(); |
1170 | this->PopScopes(1); |
1171 | return expr; |
1172 | }); |
1173 | |
1174 | Match(TokenType::kElse); |
1175 | |
1176 | auto false_branch = Block<Expr>([&] { |
1177 | this->PushScope(); |
1178 | auto expr = ParseExpr(); |
1179 | this->PopScopes(1); |
1180 | return expr; |
1181 | }); |
1182 | |
1183 | return relay::If(guard, true_branch, false_branch); |
1184 | }); |
1185 | } |
1186 | |
1187 | /* This factors parsing a list of patterns for both tuples, and constructors. */ |
1188 | Array<Pattern> ParsePatternList() { |
1189 | return ParseSequence<Pattern>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, |
1190 | [&] { return ParsePattern(); }); |
1191 | } |
1192 | |
1193 | /*! \brief Parses a pattern for a match expression. |
1194 | * |
1195 | * A pattern is either a wildcard `_`, a local `%name`, |
1196 | * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn). |
1197 | * |
1198 | * This function recursively parses a pattern. |
1199 | */ |
1200 | Pattern ParsePattern() { |
1201 | VLOG(9) << "Parser::ParsePattern" ; |
1202 | auto next = Peek(); |
1203 | switch (next->token_type) { |
1204 | case TokenType::kUnderscore: { |
1205 | Match(TokenType::kUnderscore); |
1206 | return PatternWildcard(); |
1207 | } |
1208 | case TokenType::kLocal: { |
1209 | auto id = Match(TokenType::kLocal); |
1210 | Type type_annotation; |
1211 | if (WhenMatch(TokenType::kColon)) { |
1212 | type_annotation = ParseType(); |
1213 | } |
1214 | auto var = BindVar(id.ToString(), type_annotation); |
1215 | return PatternVar(var); |
1216 | } |
1217 | case TokenType::kIdentifier: { |
1218 | auto id = Match(TokenType::kIdentifier); |
1219 | auto ctor = ctors.Get(id.ToString()); |
1220 | if (!ctor) { |
1221 | diag_ctx.EmitFatal( |
1222 | // TODO(@jroesch): split into error and help |
1223 | // deal with multiple rendering |
1224 | Diagnostic::Error(id->span) |
1225 | << "undefined constructor name `" << id.ToString() |
1226 | << "`, perhaps you intended to write a" |
1227 | << "pattern variable, considering changing this to `%" << id.ToString() << "`" ); |
1228 | } |
1229 | if (Peek()->token_type == TokenType::kOpenParen) { |
1230 | auto fields = ParsePatternList(); |
1231 | return PatternConstructor(ctor.value(), fields); |
1232 | } else { |
1233 | return PatternConstructor(ctor.value(), {}); |
1234 | } |
1235 | } |
1236 | default: |
1237 | return PatternTuple(ParsePatternList()); |
1238 | } |
1239 | } |
1240 | |
1241 | Clause ParseMatchArm() { |
1242 | PushScope(); |
1243 | auto pattern = ParsePattern(); |
1244 | Match(TokenType::kEqual); |
1245 | Consume(TokenType::kRAngle); |
1246 | auto expr = ParseExpr(); |
1247 | PopScopes(1); |
1248 | return Clause(pattern, expr); |
1249 | } |
1250 | |
1251 | Expr ParseMatch(bool is_total) { |
1252 | return WithSpan<Expr>([&]() { |
1253 | Expr scrutinee = ParseAtomicExpr(); |
1254 | |
1255 | Array<Clause> clauses = |
1256 | ParseSequence<Clause>(TokenType::kLCurly, TokenType::kComma, TokenType::kRCurly, |
1257 | [&] { return ParseMatchArm(); }); |
1258 | |
1259 | return relay::Match(scrutinee, clauses, is_total); |
1260 | }); |
1261 | } |
1262 | |
  /*!
   * \brief Parse a chain of call expressions joined by binary operators.
   *
   * Uses an operator stack (`ops`) and an operand stack (`exprs`). Each operator's
   * `precedence` and `left_assoc` decide whether pending operators are reduced into
   * `relay::Call(op, {left, right})` nodes before the new operator is pushed.
   * Reduced calls get the merged span of their two operands.
   */
  Expr ParseExprBinOp() {
    VLOG(9) << "Parser::ParseExprBinOp" ;
    return WithSpan<Expr>([this] {
      // We must parse at least one expression, the default
      // case is that there is no operator and we will fall
      // through.
      std::vector<Expr> exprs;
      Expr expr = WithSpan<Expr>([this] { return ParseCallExpr(); });

      exprs.push_back(expr);

      // Now we parse an optional op.
      std::vector<Rule> ops;

      // We will now parse 0 or more operator occurrences.
      while (true) {
        auto opt_op = ParseOp();

        // If we didn't parse one we done.
        if (opt_op.size() == 0) {
          break;
        }

        // Read the operation we parsed;
        auto op = opt_op[0];

        Expr right = WithSpan<Expr>([this] { return ParseCallExpr(); });
        ICHECK(right->span.defined());

        // If the operator stack is empty
        // we parse an operator and expression
        // and push them to stacks, then
        // continue.
        if (ops.size() == 0) {
          ops.push_back(op);
          exprs.push_back(right);
          continue;
        }

        // The new operator binds tighter than the one on top of the stack (or equally
        // tight and right-associative): defer reduction.
        if (op.precedence > ops.back().precedence ||
            (op.precedence == ops.back().precedence && op.left_assoc == false)) {
          ops.push_back(op);
          exprs.push_back(right);
          continue;
        }

        // Otherwise reduce every stacked operator that binds at least as tightly.
        // NB: the `right` declared inside this loop shadows the operand parsed above;
        // the push after the loop refers to the outer (newly parsed) operand.
        while (ops.size() && (op.precedence < ops.back().precedence ||
                              (op.precedence == ops.back().precedence && op.left_assoc == true))) {
          Rule new_op = ops.back();
          ops.pop_back();
          Expr right = exprs.back();
          exprs.pop_back();
          Expr left = exprs.back();
          exprs.pop_back();
          ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
          exprs.push_back(
              relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
        }

        exprs.push_back(right);
        ops.push_back(op);
      }

      // Reduce whatever is left, innermost (top of stack) first.
      while (ops.size()) {
        Rule new_op = ops.back();
        ops.pop_back();
        Expr right = exprs.back();
        exprs.pop_back();
        Expr left = exprs.back();
        exprs.pop_back();
        ICHECK(new_op.op.defined()) << "a call op must be set " << new_op.op;
        exprs.push_back(
            relay::Call(new_op.op, {left, right}, Attrs(), {}, left->span.Merge(right->span)));
      }

      ICHECK_EQ(ops.size(), 0) << "No operations should be left on the operation stack." ;

      ICHECK_EQ(exprs.size(), 1)
          << "Only a single expression should be left on the expression stack." ;

      return exprs[0];
    });
  }
1346 | |
1347 | ObjectRef ParseAttributeValue() { |
1348 | VLOG(9) << "Parser::ParseAttributeValue" ; |
1349 | auto next = Peek(); |
1350 | switch (next->token_type) { |
1351 | case TokenType::kFloat: |
1352 | case TokenType::kInteger: |
1353 | case TokenType::kBoolean: |
1354 | case TokenType::kStringLiteral: |
1355 | return Match(next->token_type)->data; |
1356 | case TokenType::kMetaReference: |
1357 | return ParseMetaRef(); |
1358 | case TokenType::kLSquare: { |
1359 | return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
1360 | [&]() { return ParseAttributeValue(); }); |
1361 | } |
1362 | case TokenType::kOpenParen: { |
1363 | // TODO(@jroesch: need to figure out bracket vs. sequence) |
1364 | // return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma, |
1365 | // TokenType::kCloseParen, |
1366 | // [&]() { return ParseAttributeValue(); }); |
1367 | return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen, |
1368 | [&]() { return ParseAttributeValue(); }); |
1369 | } |
1370 | // TODO(@jroesch): not sure about this being the right way to handle nulls. |
1371 | case TokenType::kIdentifier: { |
1372 | if (auto text = next->data.as<tvm::StringObj>()) { |
1373 | std::string id = GetRef<String>(text); |
1374 | if (id == "nullptr" ) { |
1375 | Match(TokenType::kIdentifier); |
1376 | return ObjectRef(); |
1377 | } |
1378 | if (id == "None" ) { |
1379 | Match(TokenType::kIdentifier); |
1380 | return Optional<ObjectRef>(); |
1381 | } |
1382 | } |
1383 | } |
1384 | default: |
1385 | return ParseAtomicExpr(); |
1386 | } |
1387 | } |
1388 | |
1389 | Map<String, ObjectRef> ParseAttrs() { |
1390 | VLOG(9) << "Parser::ParseAttrs" ; |
1391 | Map<String, ObjectRef> kwargs; |
1392 | while (Peek()->token_type == TokenType::kIdentifier) { |
1393 | auto key = GetHierarchicalName(ParseHierarchicalName().data); |
1394 | Match(TokenType::kEqual); |
1395 | // TOOD(@jroesch): syntactically what do we allow to appear in attribute right hand side. |
1396 | auto value = ParseAttributeValue(); |
1397 | // TODO(@jroesch): we need a robust way to handle this writing dtypes as strings in text |
1398 | // format is bad. |
1399 | kwargs.Set(key, value); |
1400 | WhenMatch(TokenType::kComma); |
1401 | } |
1402 | VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs; |
1403 | return kwargs; |
1404 | } |
1405 | |
  /*!
   * \brief Parse one parenthesized argument list applied to `op`.
   *
   * Returns `Call(op, args, attrs)` when the next token is '(' and an undefined Expr
   * otherwise. Call attributes may trail the positional arguments, either as
   * `key=value` pairs or as a single meta reference that is a BaseAttrsNode.
   * For an operator (`OpNode`) with a non-empty `attrs_type_key`, keyword attributes
   * are built into that attrs type; default attrs of that type are created when none
   * are given. For other callees the attrs type must be supplied via an explicit
   * `attrs_type_key` attribute, otherwise a fatal diagnostic is emitted.
   */
  Expr ParseCallArgs(Expr op) {
    ICHECK(op.defined()) << "the operator must be defined" ;

    VLOG(9) << "Parser::ParseCallArgs" ;
    Attrs attrs;
    std::string op_key;
    bool is_op = false;

    if (auto op_node = op.as<OpNode>()) {
      is_op = true;
      op_key = op_node->attrs_type_key;
    }

    if (Peek()->token_type == TokenType::kOpenParen) {
      Array<Expr> args = ParseSequence<Expr>(
          TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
          [&] { return ParseExpr(); },
          // Attribute detector. Presumably consulted by ParseSequence before each
          // element: returning true means this lambda consumed the attributes itself.
          [&] {
            auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
            auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
            auto is_pretty_attrs = is_ident && next_is_equal;
            auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
            // TODO(@jroesch): might not handle trailing comma
            auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
            auto is_meta_attrs = is_meta_next && last_meta;

            if (is_pretty_attrs || is_meta_attrs) {
              if (is_meta_attrs) {
                auto meta_ref = ParseMetaRef();
                if (meta_ref.as<BaseAttrsNode>()) {
                  attrs = Downcast<Attrs>(meta_ref);
                } else {
                  // Not awesome parsing code here.
                  // The meta reference was an ordinary argument: step `pos` back
                  // one token so it is re-parsed as an expression.
                  // NOTE(review): assumes ParseMetaRef consumed exactly one token.
                  this->pos--;
                  return false;
                }
              } else {
                auto raw_attrs = ParseAttrs();
                if (is_op && op_key.size()) {
                  auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
                  ICHECK(attr_obj.defined());
                  attrs = Downcast<Attrs>(attr_obj);
                } else if (raw_attrs.count("attrs_type_key" )) {
                  String attr_key = Downcast<String>(raw_attrs["attrs_type_key" ]);
                  if (attr_key.size()) {
                    raw_attrs.erase("attrs_type_key" );
                    auto attr_obj =
                        tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs);
                    ICHECK(attr_obj.defined());
                    attrs = Downcast<Attrs>(attr_obj);
                  }
                } else {
                  this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
                                           << "unable to determine the 'attrs_type_key' with which "
                                              "to represent the call attributes for this operator" );
                }
              }
              return true;
            }
            return false;
          });

      // No attributes were written: give typed operators default-constructed attrs.
      if (!attrs.defined()) {
        if (is_op && op_key.size()) {
          auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
          ICHECK(attr_obj.defined());
          attrs = Downcast<Attrs>(attr_obj);
        }
      }

      // TODO(@jroesch): in a secondary pass adjust spans.
      return Expr(Call(op, args, attrs, {}));
    } else {
      return Expr();
    }

    // Unreachable: both branches above return.
    return Expr();
  }
1484 | |
1485 | Expr ParseCallExpr() { |
1486 | VLOG(9) << "Parser::ParseCallExpr" ; |
1487 | return WithSpan<Expr>([this] { |
1488 | Expr expr = ParseAtomicExpr(); |
1489 | // Parse as many call args as possible, building up expression |
1490 | // |
1491 | // NB(@jroesch): this seems like a hack but in order to parse curried functions |
1492 | // and avoid complex grammar we will parse multiple call lists in a row. |
1493 | while (Peek()->token_type == TokenType::kOpenParen) { |
1494 | auto new_expr = ParseCallArgs(expr); |
1495 | |
1496 | if (new_expr.defined()) { |
1497 | expr = new_expr; |
1498 | } else { |
1499 | break; |
1500 | } |
1501 | } |
1502 | |
1503 | // We need a zero-arity case for constructors. |
1504 | if (auto ctor_node = expr.as<ConstructorNode>()) { |
1505 | if (ctor_node->inputs.size() == 0) { |
1506 | return Expr(Call(expr, {})); |
1507 | } |
1508 | } |
1509 | |
1510 | return expr; |
1511 | }); |
1512 | } |
1513 | |
1514 | Expr GetOp(const std::string& op_name, const Span& span) { |
1515 | VLOG(9) << "op_name=" << op_name << " span=" << span; |
1516 | try { |
1517 | return Op::Get(op_name); |
1518 | } catch (const Error& e) { |
1519 | // we can relax this, but probably need to relax checks or return non-null here. |
1520 | this->diag_ctx.EmitFatal(Diagnostic::Error(span) |
1521 | << "operator `" << op_name |
1522 | << "` not found, perhaps you forgot to register it?" ); |
1523 | return Expr(); |
1524 | } |
1525 | } |
1526 | |
1527 | Expr ParseAtomicExpr() { |
1528 | VLOG(9) << "Parser::ParseAtomicExpr" ; |
1529 | Expr expr = WithSpan<Expr>([this] { |
1530 | auto next = Peek(); |
1531 | switch (next->token_type) { |
1532 | case TokenType::kInteger: |
1533 | case TokenType::kFloat: { |
1534 | Consume(next->token_type); |
1535 | auto number = NumberToNDArray(next); |
1536 | Expr e = Constant(number, next->span); |
1537 | ICHECK(e->span.defined()) << "constant spans must be defined" ; |
1538 | return e; |
1539 | } |
1540 | case TokenType::kBoolean: { |
1541 | Consume(TokenType::kBoolean); |
1542 | int64_t value = Downcast<tvm::Integer>(next->data).IntValue(); |
1543 | Expr e = Constant(support::BoolToNDArray(value), next->span); |
1544 | ICHECK(e->span.defined()) << "constant spans must be defined" ; |
1545 | return e; |
1546 | } |
1547 | // Parse a local of the form `%x`. |
1548 | case TokenType::kLocal: { |
1549 | Consume(TokenType::kLocal); |
1550 | return Expr(LookupLocal(next)); |
1551 | } |
1552 | // Parse a local of the form `@x`. |
1553 | case TokenType::kGlobal: { |
1554 | auto global_name = next.ToString(); |
1555 | Consume(TokenType::kGlobal); |
1556 | auto global = AddOrGet(&global_names, global_name); |
1557 | return Expr(global); |
1558 | } |
1559 | // Parse a local of the form `x`. |
1560 | // Right now we fail to parse `x.y`. |
1561 | case TokenType::kIdentifier: { |
1562 | auto ctor = ctors.Get(next.ToString()); |
1563 | if (ctor) { |
1564 | Consume(TokenType::kIdentifier); |
1565 | return Expr(ctor.value()); |
1566 | } else { |
1567 | auto spanned_idents = ParseHierarchicalName(); |
1568 | auto idents = spanned_idents.data; |
1569 | auto span = spanned_idents.span; |
1570 | return GetOp(GetHierarchicalName(idents), span); |
1571 | } |
1572 | } |
1573 | case TokenType::kGraph: { |
1574 | Consume(TokenType::kGraph); |
1575 | return LookupGraphBinding(next); |
1576 | } |
1577 | case TokenType::kMetaReference: { |
1578 | return Downcast<Expr>(ParseMetaRef()); |
1579 | } |
1580 | case TokenType::kFn: { |
1581 | Consume(TokenType::kFn); |
1582 | Expr e = ParseFunctionDef(); |
1583 | ICHECK(e->span.defined()) << "function spans must be defined.\n" << e; |
1584 | return e; |
1585 | } |
1586 | case TokenType::kIf: { |
1587 | Expr e = ParseIf(); |
1588 | return e; |
1589 | } |
1590 | case TokenType::kRef: { |
1591 | Consume(TokenType::kRef); |
1592 | Match(TokenType::kOpenParen); |
1593 | auto ref_value = ParseExpr(); |
1594 | Match(TokenType::kCloseParen); |
1595 | return static_cast<Expr>(RefCreate(ref_value)); |
1596 | } |
1597 | case TokenType::kRefRead: { |
1598 | return WithSpan<Expr>([&]() { |
1599 | Consume(TokenType::kRefRead); |
1600 | Match(TokenType::kOpenParen); |
1601 | auto ref = ParseExpr(); |
1602 | Match(TokenType::kCloseParen); |
1603 | return static_cast<Expr>(RefRead(ref)); |
1604 | }); |
1605 | } |
1606 | case TokenType::kRefWrite: { |
1607 | return WithSpan<Expr>([&]() { |
1608 | Consume(TokenType::kRefWrite); |
1609 | Match(TokenType::kOpenParen); |
1610 | auto ref = ParseExpr(); |
1611 | Match(TokenType::kComma); |
1612 | auto value = ParseExpr(); |
1613 | Match(TokenType::kCloseParen); |
1614 | return static_cast<Expr>(RefWrite(ref, value)); |
1615 | }); |
1616 | } |
1617 | case TokenType::kOpenParen: { |
1618 | Span sp = next->span; |
1619 | Consume(TokenType::kOpenParen); |
1620 | // parse '(' ')' |
1621 | if (WhenMatch(TokenType::kCloseParen)) { |
1622 | return Expr(Tuple(Array<Expr>())); |
1623 | } else { |
1624 | Expr subexpr = ParseExpr(); |
1625 | // parse '(' expr ')' |
1626 | if (WhenMatch(TokenType::kCloseParen)) { |
1627 | return subexpr; |
1628 | // parse '( expr ',' * ')' |
1629 | } else if (WhenMatch(TokenType::kComma)) { |
1630 | Array<Expr> exprs = {subexpr}; |
1631 | while (true) { |
1632 | if (WhenMatch(TokenType::kCloseParen)) { |
1633 | break; |
1634 | } else { |
1635 | auto element = ParseExpr(); |
1636 | auto comma = Peek(); |
1637 | if (WhenMatch(TokenType::kComma)) { |
1638 | sp = sp.Merge(element->span.Merge(comma->span)); |
1639 | } else { |
1640 | sp = sp.Merge(element->span); |
1641 | } |
1642 | exprs.push_back(element); |
1643 | } |
1644 | } |
1645 | Expr tuple = Tuple(exprs, sp); |
1646 | ICHECK(tuple->span.defined()) << "tuple span should be defined" ; |
1647 | return tuple; |
1648 | } |
1649 | } |
1650 | } |
1651 | default: { |
1652 | this->diag_ctx.EmitFatal(Diagnostic::Error(next->span) |
1653 | << "expected an expression found " << Pretty(next->token_type)); |
1654 | return Expr(); |
1655 | } |
1656 | } |
1657 | }); |
1658 | |
1659 | if (WhenMatch(TokenType::kPeriod)) { |
1660 | auto token = Match(TokenType::kInteger); |
1661 | auto index = token.ToNumber(); |
1662 | auto span = token->span.Merge(expr->span); |
1663 | VLOG(9) << "Parser::ParseAtomicExpr: tuple get item" ; |
1664 | return relay::TupleGetItem(expr, index, span); |
1665 | } else { |
1666 | return expr; |
1667 | } |
1668 | } |
1669 | |
1670 | /*! \brief Parse a hierarchical name. |
1671 | * |
1672 | * The tokenizer produces a token stream of <id1> . <id2> |
1673 | * and so on for names of the form `nn.conv2d`. |
1674 | * Currently we only use string names everywhere instead |
1675 | * of a notion of a hierarchical name. |
1676 | * |
1677 | * The below utility reassembles a token stream into a |
1678 | * single stream inserting the required periods needed |
1679 | * to look up registered names. |
1680 | */ |
1681 | Spanned<Array<String>> ParseHierarchicalName() { |
1682 | Array<String> idents; |
1683 | Span span; |
1684 | while (Peek()->token_type == TokenType::kIdentifier) { |
1685 | auto token = Peek(); |
1686 | |
1687 | if (span.defined()) { |
1688 | span = span.Merge(token->span); |
1689 | } else { |
1690 | span = token->span; |
1691 | } |
1692 | |
1693 | auto name = token.ToString(); |
1694 | idents.push_back(name); |
1695 | Consume(TokenType::kIdentifier); |
1696 | |
1697 | // Keep parsing while we see a trailing period. |
1698 | if (Peek()->token_type == TokenType::kPeriod) { |
1699 | Consume(TokenType::kPeriod); |
1700 | continue; |
1701 | } else { |
1702 | // No more periods means we are done! |
1703 | break; |
1704 | } |
1705 | } |
1706 | |
1707 | return Spanned<Array<String>>(idents, span); |
1708 | } |
1709 | |
1710 | std::string GetHierarchicalName(Array<String> idents) { |
1711 | ICHECK_NE(idents.size(), 0); |
1712 | std::stringstream hierarchical_name; |
1713 | int i = 0; |
1714 | int periods = idents.size() - 1; |
1715 | for (auto ident : idents) { |
1716 | hierarchical_name << ident; |
1717 | if (i < periods) { |
1718 | hierarchical_name << "." ; |
1719 | i++; |
1720 | } |
1721 | } |
1722 | return hierarchical_name.str(); |
1723 | } |
1724 | |
1725 | /*! \brief Parse a shape. */ |
1726 | Array<tvm::PrimExpr> ParseShape() { |
1727 | auto dims = ParseSequence<tvm::PrimExpr>( |
1728 | TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() { |
1729 | tvm::PrimExpr dim; |
1730 | if (Peek()->token_type == TokenType::kMetaReference) { |
1731 | dim = Downcast<tvm::PrimExpr>(ParseMetaRef()); |
1732 | } else if (WhenMatch(TokenType::kQuestion)) { |
1733 | dim = tvm::tir::Any(); |
1734 | } else { |
1735 | dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data); |
1736 | } |
1737 | |
1738 | return dim; |
1739 | }); |
1740 | return dims; |
1741 | } |
1742 | |
1743 | /*! \brief Parse a function type. */ |
1744 | Type ParseFunctionType() { |
1745 | auto ty_params = ParseSequence<Type>(TokenType::kOpenParen, TokenType::kComma, |
1746 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
1747 | |
1748 | Match(TokenType::kMinus); |
1749 | Match(TokenType::kRAngle); |
1750 | auto ret_type = ParseType(); |
1751 | |
1752 | return relay::FuncType(ty_params, ret_type, {}, {}); |
1753 | } |
1754 | |
1755 | // Parses a user defined ADT or type variable. |
1756 | Type ParseNonPrimitiveType(const Token& tok) { |
1757 | return WithSpan<Type>([&]() { |
1758 | auto name = tok.ToString(); |
1759 | Type head_type = LookupTypeVar(tok); |
1760 | |
1761 | if (!head_type.defined()) { |
1762 | // head_type = type_names.Get(name); |
1763 | head_type = AddOrGet(&type_names, name, TypeKind::kAdtHandle); |
1764 | } |
1765 | |
1766 | if (!head_type.defined()) { |
1767 | diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
1768 | << "the type constructor `" << name << "` is undefined" ); |
1769 | } |
1770 | |
1771 | Array<Type> arg_types; |
1772 | if (Peek()->token_type == TokenType::kLSquare) { |
1773 | arg_types = ParseSequence<Type>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, |
1774 | [&]() { return ParseType(); }); |
1775 | } |
1776 | |
1777 | if (arg_types.size()) { |
1778 | return static_cast<Type>(TypeCall(head_type, arg_types)); |
1779 | } else { |
1780 | if (head_type.as<GlobalTypeVarNode>()) { |
1781 | return static_cast<Type>(TypeCall(head_type, {})); |
1782 | } else { |
1783 | return static_cast<Type>(head_type); |
1784 | } |
1785 | } |
1786 | }); |
1787 | } |
1788 | |
1789 | /*! \brief Parses a TVM type. |
1790 | * |
1791 | * This matches either a `Tensor[shape, dtype]`, a user defined ADT, a tuple type, |
1792 | * a scalar type or an incomplete type `_`. |
1793 | */ |
1794 | Type ParseType() { |
1795 | return WithSpan<Type>([&]() -> Type { |
1796 | auto tok = Peek(); |
1797 | |
1798 | if (tok->token_type == TokenType::kOpenParen) { |
1799 | auto tys = |
1800 | ParseSequence<relay::Type>(TokenType::kOpenParen, TokenType::kComma, |
1801 | TokenType::kCloseParen, [&]() { return ParseType(); }); |
1802 | return relay::TupleType(tys); |
1803 | } else if (WhenMatch(TokenType::kFn)) { |
1804 | return ParseFunctionType(); |
1805 | } else if (WhenMatch(TokenType::kIdentifier)) { |
1806 | auto id = tok.ToString(); |
1807 | if (id == "Tensor" ) { |
1808 | Match(TokenType::kLSquare); |
1809 | auto shape = ParseShape(); |
1810 | Match(TokenType::kComma); |
1811 | auto dtype_tok = Match(TokenType::kIdentifier); |
1812 | auto dtype = DataType(String2DLDataType(dtype_tok.ToString())); |
1813 | Match(TokenType::kRSquare); |
1814 | return TensorType(shape, dtype); |
1815 | } else { |
1816 | auto ty = tok.ToString(); |
1817 | if (ty.rfind("int" , 0) == 0 || ty.find("float" , 0) == 0 || ty.find("uint" , 0) == 0 || |
1818 | ty.find("bool" , 0) == 0) { |
1819 | // Need to do better error handling here. |
1820 | auto dtype = DataType(String2DLDataType(tok.ToString())); |
1821 | return TensorType({}, dtype); |
1822 | } else { |
1823 | return ParseNonPrimitiveType(tok); |
1824 | } |
1825 | } |
1826 | } else if (WhenMatch(TokenType::kUnderscore)) { |
1827 | return IncompleteType(); |
1828 | } else { |
1829 | this->diag_ctx.EmitFatal(Diagnostic::Error(tok->span) |
1830 | << "failed to parse type found " << tok); |
1831 | return Type(); |
1832 | } |
1833 | }); |
1834 | } |
1835 | |
1836 | template <typename R> |
1837 | R ConsumeWhitespace(std::function<R()> func) { |
1838 | auto old = this->ignore_whitespace; |
1839 | this->ignore_whitespace = true; |
1840 | while (tokens[pos]->token_type == TokenType::kWhitespace) { |
1841 | pos++; |
1842 | } |
1843 | auto res = func(); |
1844 | this->ignore_whitespace = old; |
1845 | return res; |
1846 | } |
1847 | |
1848 | Map<String, Array<ObjectRef>> ParseMetadata() { |
1849 | if (Peek()->token_type == TokenType::kMetadata) { |
1850 | return Match(TokenType::kMetadata).ToMetadata(); |
1851 | } else { |
1852 | return Map<String, Array<ObjectRef>>(); |
1853 | } |
1854 | } |
1855 | |
1856 | /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */ |
1857 | void DisplayNextN(int n) { |
1858 | std::cout << "remaining tokens: " << std::endl; |
1859 | auto bound = std::min(pos + n, static_cast<int>(tokens.size())); |
1860 | for (int i = 0; i < bound - pos; i++) { |
1861 | std::cout << tokens[pos + i] << std::endl; |
1862 | } |
1863 | } |
1864 | |
1865 | // A function for debugging the operator parser. |
1866 | void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) { |
1867 | std::cout << "Expr Stack: " ; |
1868 | for (auto expr : exprs) { |
1869 | std::cout << expr << ", " ; |
1870 | } |
1871 | |
1872 | std::cout << std::endl; |
1873 | std::cout << "Op Stack: " ; |
1874 | for (auto rule : rules) { |
1875 | std::cout << rule.op << ", " ; |
1876 | } |
1877 | |
1878 | std::cout << std::endl; |
1879 | } |
1880 | }; |
1881 | |
1882 | Parser InitParser(const std::string& file_name, const std::string& file_content, |
1883 | const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
1884 | VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); |
1885 | SourceName src_name = SourceName::Get(file_name); |
1886 | Source source(src_name, file_content); |
1887 | |
1888 | IRModule module; |
1889 | if (!init_module) { |
1890 | SourceMap source_map; |
1891 | module = IRModule({}, {}, {}, source_map); |
1892 | } else { |
1893 | module = init_module.value(); |
1894 | } |
1895 | |
1896 | module->source_map.Add(source); |
1897 | |
1898 | auto diag_ctx = DiagnosticContext::Default(module); |
1899 | auto tokens_and_table = Tokenize(diag_ctx, source); |
1900 | |
1901 | auto tokens = tokens_and_table.first; |
1902 | MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); |
1903 | |
1904 | // Merge any entries in init_meta_table into anything captured in the #[metadata] section |
1905 | // of the file_content. Metadata references within file_content must use indexes which account |
1906 | // for this ordering. |
1907 | for (const auto& pair : init_meta_table) { |
1908 | Array<ObjectRef> items; |
1909 | if (meta_data_table.count(pair.first)) { |
1910 | items = meta_data_table[pair.first]; |
1911 | } |
1912 | for (const auto& obj : pair.second) { |
1913 | items.push_back(obj); |
1914 | } |
1915 | meta_data_table.Set(pair.first, items); |
1916 | } |
1917 | |
1918 | return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); |
1919 | } |
1920 | |
1921 | IRModule ParseModule(const std::string& file_name, const std::string& file_content, |
1922 | const Optional<IRModule>& init_module, const MetaTable& init_meta_table) { |
1923 | VLOG_CONTEXT << "ParseModule" ; |
1924 | VLOG(9) << "parsing and type-checking " << file_name; |
1925 | auto parser = InitParser(file_name, file_content, init_module, init_meta_table); |
1926 | auto mod = parser.ParseModule(); |
1927 | ICHECK(mod.defined()) << "The parser must return a non-null module." ; |
1928 | // NB(@jroesch): it is very important that we render any errors before we proceed |
1929 | // if there were any errors which allow the parser to proceed we must render them |
1930 | // here. |
1931 | parser.diag_ctx.Render(); |
1932 | auto infer_type = tvm::relay::transform::InferType(); |
1933 | ICHECK(infer_type.defined()) << "The type inferencer must be non-null." ; |
1934 | return infer_type(mod); |
1935 | } |
1936 | |
1937 | Expr ParseExpr(const std::string& file_name, const std::string& file_content) { |
1938 | VLOG(9) << "ParseExpr" ; |
1939 | auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable()); |
1940 | parser.ParseSemVer(false); |
1941 | parser.PushScope(); |
1942 | auto expr = parser.ParseExpr(); |
1943 | parser.Match(TokenType::kEndOfFile); |
1944 | // NB(@jroesch): it is very important that we render any errors before we proceed |
1945 | // if there were any errors which allow the parser to proceed we must render them |
1946 | // here. |
1947 | parser.diag_ctx.Render(); |
1948 | return expr; |
1949 | } |
1950 | |
// FFI entry point: parse a module, optionally into an existing module and with extra metadata.
TVM_REGISTER_GLOBAL("parser.ParseModuleInContext")
    .set_body_typed([](const std::string& file_name, const std::string& file_content,
                       const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
      return ParseModule(file_name, file_content, init_module, init_meta_table);
    });

// FFI entry point: parse a module from text alone. The omitted arguments take the
// default values declared for ParseModule elsewhere (presumably in the parser header).
TVM_REGISTER_GLOBAL("parser.ParseModule")
    .set_body_typed([](const std::string& file_name, const std::string& file_content) {
      return ParseModule(file_name, file_content);
    });

// FFI entry point: parse a single standalone expression.
TVM_REGISTER_GLOBAL("parser.ParseExpr")
    .set_body_typed([](tvm::String file_name, tvm::String file_content) {
      return ParseExpr(file_name, file_content);
    });
1966 | |
1967 | /*! |
1968 | * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources |
1969 | * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for |
1970 | * modules constructed programaticaly rather than textually. |
1971 | */ |
1972 | Pass AnnotateSpans() { |
1973 | auto pass_func = [](const IRModule& mod, const PassContext& ctx) { |
1974 | String text = AsText(mod, /*show_meta_data=*/true); |
1975 | VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; |
1976 | return ParseModule("GeneratedSource" , text); |
1977 | }; |
1978 | return CreateModulePass(pass_func, 0, "AnnotateSpans" , {}); |
1979 | } |
1980 | |
1981 | TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans" ).set_body_typed(AnnotateSpans); |
1982 | |
1983 | } // namespace parser |
1984 | } // namespace tvm |
1985 | |