1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
6// by this parser is preliminary and may change.
7
8#include <stdexcept>
9#include <string>
10#include <unordered_map>
11
12#include "onnx/onnx_pb.h"
13#include "onnx/string_utils.h"
14
15#include "onnx/defs/parser.h"
16
// Shorthand wrappers that pass the result of a parser call to CHECK_PARSER_STATUS.
// CHECK_PARSER_STATUS is not defined in this file; presumably it returns from the enclosing
// function when the wrapped call reports a failure -- confirm against its definition.
//   PARSE_TOKEN(x): wraps the base-class overload ParserBase::Parse(x).
//   PARSE(...):     wraps an unqualified call to Parse(...) with the given arguments.
//   MATCH(...):     wraps a call to Match(...) with the given arguments.
17#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
18#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
19#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))
20
21namespace ONNX_NAMESPACE {
22
23Status ParserBase::Parse(Literal& result) {
24 bool decimal_point = false;
25 auto nextch = NextChar();
26 auto from = next_;
27 if (nextch == '"') {
28 ++next_;
29 bool has_escape = false;
30 while ((next_ < end_) && (*next_ != '"')) {
31 if (*next_ == '\\') {
32 has_escape = true;
33 ++next_;
34 if (next_ >= end_)
35 return ParseError("Incomplete string literal.");
36 }
37 ++next_;
38 }
39 if (next_ >= end_)
40 return ParseError("Incomplete string literal.");
41 ++next_;
42 result.type = LiteralType::STRING_LITERAL;
43 if (has_escape) {
44 std::string& target = result.value;
45 target.clear();
46 target.reserve(next_ - from - 2); // upper bound
47 // *from is the starting quote. *(next_-1) is the ending quote.
48 // Copy what is in-between, except for the escape character
49 while (++from < next_ - 1) {
50 // Copy current char, if not escape, or next char otherwise.
51 target.push_back(*from != '\\' ? (*from) : *(++from));
52 }
53 } else
54 result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
55 } else if ((isdigit(nextch) || (nextch == '-'))) {
56 ++next_;
57
58 while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
59 if (*next_ == '.') {
60 if (decimal_point)
61 break; // Only one decimal point allowed in numeric literal
62 decimal_point = true;
63 }
64 ++next_;
65 }
66
67 if (next_ == from)
68 return ParseError("Value expected but not found.");
69
70 // Optional exponent syntax: (e|E)(+|-)?[0-9]+
71 if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
72 decimal_point = true; // treat as float-literal
73 ++next_;
74 if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
75 ++next_;
76 while ((next_ < end_) && (isdigit(*next_)))
77 ++next_;
78 }
79
80 result.value = std::string(from, next_ - from);
81 result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
82 }
83 return Status::OK();
84}
85
86Status OnnxParser::Parse(IdList& idlist) {
87 idlist.Clear();
88 std::string id;
89 ParseOptionalIdentifier(id);
90 if (id.empty())
91 return Status::OK(); // Treat as empty list of identifiers
92 *idlist.Add() = id;
93 while (Matches(',')) {
94 ParseOptionalIdentifier(id);
95 *idlist.Add() = id;
96 }
97 return Status::OK();
98}
99
100Status OnnxParser::Parse(char open, IdList& idlist, char close) {
101 idlist.Clear();
102 if (Matches(open)) {
103 PARSE(idlist);
104 MATCH(close);
105 }
106 return Status::OK();
107}
108
109Status OnnxParser::Parse(TensorShapeProto& shape) {
110 shape.clear_dim();
111 do {
112 if (Matches('?')) {
113 shape.add_dim();
114 } else {
115 // Check for a symbolic identifier ...
116 std::string id;
117 CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
118 if (!id.empty()) {
119 shape.add_dim()->set_dim_param(id);
120 } else {
121 // ...or a integer value
122 int64_t dimval = 0;
123 PARSE_TOKEN(dimval);
124 shape.add_dim()->set_dim_value(dimval);
125 }
126 }
127 } while (Matches(','));
128 return Status::OK();
129}
130
131Status OnnxParser::Parse(TypeProto& typeProto) {
132 std::string id;
133 CHECK_PARSER_STATUS(ParseIdentifier(id));
134 int dtype = PrimitiveTypeNameMap::Lookup(id);
135 if (dtype != 0) {
136 auto* tensortype = typeProto.mutable_tensor_type();
137 tensortype->set_elem_type(dtype);
138 tensortype->clear_shape();
139 // Grammar:
140 // float indicates scalar (rank 0)
141 // float [] indicates unknown rank tensor (not a zero rank tensor)
142 // float [one-or-more-dimensions] indicates tensor of known rank > 0.
143 if (Matches('[')) {
144 if (!Matches(']')) {
145 PARSE(*tensortype->mutable_shape());
146 MATCH(']');
147 }
148 } else {
149 // Create shape with zero dimensions for scalar
150 (void)(tensortype->mutable_shape());
151 }
152 } else {
153 switch (KeyWordMap::Lookup(id)) {
154 case KeyWordMap::KeyWord::SEQ_TYPE: {
155 // Grammar: seq ( type )
156 MATCH('(');
157 auto* seqtype = typeProto.mutable_sequence_type();
158 PARSE(*seqtype->mutable_elem_type());
159 MATCH(')');
160 break;
161 }
162 case KeyWordMap::KeyWord::MAP_TYPE: {
163 // Grammar: map ( prim-type , type )
164 MATCH('(');
165 auto* maptype = typeProto.mutable_map_type();
166 CHECK_PARSER_STATUS(ParseIdentifier(id));
167 dtype = PrimitiveTypeNameMap::Lookup(id);
168 if (dtype == 0) {
169 return ParseError("Expecting primitive type as map key type.");
170 }
171 maptype->set_key_type(dtype);
172 MATCH(',');
173 PARSE(*maptype->mutable_value_type());
174 MATCH(')');
175 break;
176 }
177 case KeyWordMap::KeyWord::OPTIONAL_TYPE: {
178 // Grammar: optional ( type )
179 MATCH('(');
180 auto* opttype = typeProto.mutable_optional_type();
181 PARSE(*opttype->mutable_elem_type());
182 MATCH(')');
183 break;
184 }
185 case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
186 // Grammar: sparse_tensor ( tensor-type )
187 MATCH('(');
188 CHECK_PARSER_STATUS(ParseIdentifier(id));
189 dtype = PrimitiveTypeNameMap::Lookup(id);
190 if (dtype != 0) {
191 auto* sparsetype = typeProto.mutable_sparse_tensor_type();
192 sparsetype->set_elem_type(dtype);
193 sparsetype->clear_shape();
194 // Grammar:
195 // float indicates scalar (rank 0)
196 // float [] indicates unknown rank tensor (not a zero rank tensor)
197 // float [one-or-more-dimensions] indicates tensor of known rank > 0.
198 if (Matches('[')) {
199 if (!Matches(']')) {
200 PARSE(*sparsetype->mutable_shape());
201 MATCH(']');
202 }
203 } else {
204 // Create shape with zero dimensions for scalar
205 (void)(sparsetype->mutable_shape());
206 }
207 } else {
208 return ParseError("Unexpected type in sparse-tensor element type.");
209 }
210 MATCH(')');
211 break;
212 }
213 default:
214 return ParseError("Unexpected type.");
215 }
216 }
217 return Status::OK();
218}
219
220Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
221 if (NextIsType())
222 PARSE(*valueinfo.mutable_type());
223 std::string name;
224 CHECK_PARSER_STATUS(ParseIdentifier(name));
225 valueinfo.set_name(name);
226 return Status::OK();
227}
228
229Status OnnxParser::Parse(ValueInfoList& vilist) {
230 vilist.Clear();
231 MATCH('(');
232 if (!Matches(')')) {
233 do {
234 PARSE(*vilist.Add());
235 } while (Matches(','));
236 MATCH(')');
237 }
238 return Status::OK();
239}
240
241// Each input element is a value-info with an optional initializer of the form "= initial-value".
242// The value-info is added to the "inputs", while the initializer is added to initializers.
243Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
244 inputs.Clear();
245 if (Matches('(')) {
246 if (!Matches(')')) {
247 do {
248 ValueInfoProto vi;
249 PARSE(vi);
250 *inputs.Add() = vi;
251 if (Matches('=')) {
252 // default value for input
253 TensorProto& tp = *initializers.Add();
254 tp.set_name(vi.name());
255 CHECK_PARSER_STATUS(Parse(tp, vi.type()));
256 }
257 } while (Matches(','));
258 MATCH(')');
259 }
260 }
261 return Status::OK();
262}
263
264// This is handled slightly different from the inputs.
265// Each element is either a value-info or an initializer.
266// A value-info is added to the "value_infos", while an initializer is added to initializers.
267Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
268 value_infos.Clear();
269 if (Matches('<')) {
270 if (!Matches('>')) {
271 do {
272 ValueInfoProto vi;
273 PARSE(vi);
274 if (Matches('=')) {
275 // initializer
276 TensorProto& tp = *initializers.Add();
277 tp.set_name(vi.name());
278 CHECK_PARSER_STATUS(Parse(tp, vi.type()));
279 } else {
280 // valueinfo
281 *value_infos.Add() = vi;
282 }
283 } while (Matches(','));
284 MATCH('>');
285 }
286 }
287 return Status::OK();
288}
289
290Status OnnxParser::Parse(TensorProto& tensorProto) {
291 tensorProto = TensorProto();
292 // Parse the concrete tensor-type with numeric dimensions:
293 TypeProto typeProto;
294 PARSE(typeProto);
295 ParseOptionalIdentifier(*tensorProto.mutable_name());
296 (void)Matches('='); // Optional, to unify handling of initializers as well as tensor-protos in other contexts
297 return Parse(tensorProto, typeProto);
298}
299
300// Parse TensorProto data given its type:
301Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) {
302 if (!tensorTypeProto.has_tensor_type())
303 return ParseError("Error parsing TensorProto (expected a tensor type).");
304 auto elem_type = tensorTypeProto.tensor_type().elem_type();
305 tensorProto.set_data_type(elem_type);
306 if (!tensorTypeProto.tensor_type().has_shape())
307 return ParseError("Error parsing TensorProto (expected a tensor shape).");
308 uint64_t n = 1;
309 for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) {
310 if (!dim.has_dim_value())
311 return ParseError("Error parsing TensorProto shape (expected numeric dimension).");
312 auto dimval = dim.dim_value();
313 tensorProto.add_dims(dimval);
314 n *= dimval;
315 }
316
317 // tensorProto.mutable_int64_data()->Reserve(n);
318 // Parse the actual values:
319
320 int64_t intval;
321 uint64_t uintval;
322 float floatval;
323 double dblval;
324 std::string strval;
325 MATCH('{');
326 if (!Matches('}')) {
327 do {
328 switch (static_cast<TensorProto::DataType>(elem_type)) {
329 case TensorProto::DataType::TensorProto_DataType_INT8:
330 case TensorProto::DataType::TensorProto_DataType_INT16:
331 case TensorProto::DataType::TensorProto_DataType_INT32:
332 case TensorProto::DataType::TensorProto_DataType_UINT8:
333 case TensorProto::DataType::TensorProto_DataType_UINT16:
334 case TensorProto::DataType::TensorProto_DataType_BOOL:
335 PARSE_TOKEN(intval);
336 // TODO: check values are in the correct range.
337 tensorProto.add_int32_data(intval);
338 break;
339 case TensorProto::DataType::TensorProto_DataType_INT64:
340 PARSE_TOKEN(intval);
341 tensorProto.add_int64_data(intval);
342 break;
343 case TensorProto::DataType::TensorProto_DataType_UINT32:
344 case TensorProto::DataType::TensorProto_DataType_UINT64:
345 PARSE_TOKEN(uintval);
346 tensorProto.add_uint64_data(uintval);
347 break;
348 case TensorProto::DataType::TensorProto_DataType_FLOAT:
349 PARSE_TOKEN(floatval);
350 tensorProto.add_float_data(floatval);
351 break;
352 case TensorProto::DataType::TensorProto_DataType_DOUBLE:
353 PARSE_TOKEN(dblval);
354 tensorProto.add_double_data(dblval);
355 break;
356 case TensorProto::DataType::TensorProto_DataType_STRING:
357 PARSE_TOKEN(strval);
358 tensorProto.add_string_data(strval);
359 break;
360 default:
361 return ParseError("Unhandled type: %d", elem_type);
362 }
363 } while (Matches(','));
364 MATCH('}');
365 }
366 return Status::OK();
367}
368
369bool OnnxParser::NextIsType() {
370 std::string id("");
371 (void)PeekIdentifier(id);
372 return (PrimitiveTypeNameMap::IsTypeName(id));
373}
374
375Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) {
376 // Parse a single-value
377 auto next = NextChar();
378 if (isalpha(next) || next == '_') {
379 if (NextIsType()) {
380 attr.set_type(AttributeProto_AttributeType_TENSOR);
381 Parse(*attr.mutable_t());
382 } else {
383 attr.set_type(AttributeProto_AttributeType_GRAPH);
384 Parse(*attr.mutable_g());
385 }
386 } else if (Matches('@')) {
387 std::string name;
388 CHECK_PARSER_STATUS(ParseIdentifier(name));
389 attr.set_ref_attr_name(name);
390 } else {
391 Literal literal;
392 PARSE_TOKEN(literal);
393 switch (literal.type) {
394 case LiteralType::INT_LITERAL:
395 attr.set_type(AttributeProto_AttributeType_INT);
396 attr.set_i(std::stol(literal.value));
397 break;
398 case LiteralType::FLOAT_LITERAL:
399 attr.set_type(AttributeProto_AttributeType_FLOAT);
400 attr.set_f(static_cast<float>(std::stof(literal.value)));
401 break;
402 case LiteralType::STRING_LITERAL:
403 attr.set_type(AttributeProto_AttributeType_STRING);
404 attr.set_s(literal.value);
405 break;
406 default:
407 return ParseError("Unexpected literal type.");
408 }
409 }
410 return Status::OK();
411}
412
413Status OnnxParser::Parse(AttributeProto& attr) {
414 attr.Clear();
415 std::string name;
416 CHECK_PARSER_STATUS(ParseIdentifier(name));
417 attr.set_name(name);
418 if (Matches(':')) {
419 CHECK_PARSER_STATUS(ParseIdentifier(name));
420 int attrtype = AttributeTypeNameMap::Lookup(name);
421 if (attrtype != 0) {
422 attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
423 } else {
424 return ParseError("Unexpected attribute type.");
425 }
426 }
427 MATCH('=');
428 if (NextChar() == '[') {
429 // Parse a list of values. For now, empty list is not allowed, as we need to
430 // figure out a type for the attribute.
431 std::vector<Literal> vals;
432 MATCH('[');
433 do {
434 AttributeProto nextval;
435 CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval));
436 switch (nextval.type()) {
437 case AttributeProto_AttributeType_INT:
438 attr.set_type(AttributeProto_AttributeType_INTS);
439 attr.add_ints(nextval.i());
440 break;
441 case AttributeProto_AttributeType_FLOAT:
442 attr.set_type(AttributeProto_AttributeType_FLOATS);
443 attr.add_floats(nextval.f());
444 break;
445 case AttributeProto_AttributeType_STRING:
446 attr.add_strings(nextval.s());
447 attr.set_type(AttributeProto_AttributeType_STRINGS);
448 break;
449 default:
450 break;
451 }
452 } while (Matches(','));
453 MATCH(']');
454 } else {
455 CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr));
456 }
457 return Status::OK();
458}
459
460Status OnnxParser::Parse(AttrList& attrlist) {
461 attrlist.Clear();
462 if (Matches('<')) {
463 do {
464 PARSE(*attrlist.Add());
465 } while (Matches(','));
466 MATCH('>');
467 }
468 return Status::OK();
469}
470
471Status OnnxParser::Parse(NodeProto& node) {
472 PARSE(*node.mutable_output());
473 MATCH('=');
474 std::string domain("");
475 std::string id;
476 ParseIdentifier(id);
477 while (Matches('.')) {
478 if (!domain.empty())
479 domain += ".";
480 domain += id;
481 ParseIdentifier(id);
482 }
483 node.set_domain(domain);
484 node.set_op_type(id);
485 PARSE(*node.mutable_attribute());
486 MATCH('(');
487 PARSE(*node.mutable_input());
488 MATCH(')');
489 if (node.attribute_size() == 0) {
490 // Permit attributes to be specified before or after parameters.
491 PARSE(*node.mutable_attribute());
492 }
493 return Status::OK();
494}
495
496Status OnnxParser::Parse(NodeList& nodelist) {
497 nodelist.Clear();
498 MATCH('{');
499 while (!Matches('}')) {
500 PARSE(*nodelist.Add());
501 }
502 return Status::OK();
503}
504
505Status OnnxParser::Parse(GraphProto& graph) {
506 std::string id;
507 ParseIdentifier(id);
508 return Parse(id, graph);
509}
510
511Status OnnxParser::Parse(std::string name, GraphProto& graph) {
512 graph.set_name(name);
513 graph.mutable_initializer()->Clear();
514 CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *graph.mutable_initializer()));
515 MATCH('=');
516 MATCH('>', false);
517 PARSE(*graph.mutable_output());
518 CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer()));
519 return Parse(*graph.mutable_node());
520}
521
522Status OnnxParser::Parse(FunctionProto& fn) {
523 fn.Clear();
524 std::string strval;
525 if (Matches('<')) {
526 do {
527 KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
528 PARSE_TOKEN(keyword);
529 MATCH(':');
530 switch (keyword) {
531 case KeyWordMap::KeyWord::OPSET_IMPORT:
532 PARSE(*fn.mutable_opset_import());
533 break;
534 case KeyWordMap::KeyWord::DOC_STRING:
535 PARSE_TOKEN(strval);
536 fn.set_doc_string(strval);
537 break;
538 case KeyWordMap::KeyWord::DOMAIN_KW:
539 PARSE_TOKEN(strval);
540 fn.set_domain(strval);
541 break;
542 default:
543 return ParseError("Unhandled keyword.");
544 }
545 } while (Matches(','));
546 MATCH('>');
547 }
548 std::string id;
549 ParseIdentifier(id);
550 fn.set_name(id);
551
552 PARSE('<', *fn.mutable_attribute(), '>');
553 PARSE('(', *fn.mutable_input(), ')');
554 MATCH('=');
555 MATCH('>', false);
556 PARSE('(', *fn.mutable_output(), ')');
557 return Parse(*fn.mutable_node());
558}
559
560Status OnnxParser::Parse(OpsetIdList& opsets) {
561 std::string strval;
562 int64_t intval = 0;
563 MATCH('[');
564 if (!Matches(']')) {
565 do {
566 auto* import = opsets.Add();
567 PARSE_TOKEN(strval);
568 import->set_domain(strval);
569 MATCH(':');
570 PARSE_TOKEN(intval);
571 import->set_version(intval);
572 } while (Matches(','));
573 MATCH(']');
574 }
575 return Status::OK();
576}
577
578Status OnnxParser::Parse(ModelProto& model) {
579 model.Clear();
580 std::string strval;
581 int64_t intval;
582 if (Matches('<')) {
583 do {
584 KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
585 PARSE_TOKEN(keyword);
586 MATCH(':');
587 switch (keyword) {
588 case KeyWordMap::KeyWord::IR_VERSION:
589 PARSE_TOKEN(intval);
590 model.set_ir_version(intval);
591 break;
592 case KeyWordMap::KeyWord::OPSET_IMPORT:
593 PARSE(*model.mutable_opset_import());
594 break;
595 case KeyWordMap::KeyWord::PRODUCER_NAME:
596 PARSE_TOKEN(strval);
597 model.set_producer_name(strval);
598 break;
599 case KeyWordMap::KeyWord::PRODUCER_VERSION:
600 PARSE_TOKEN(strval);
601 model.set_producer_version(strval);
602 break;
603 case KeyWordMap::KeyWord::DOMAIN_KW:
604 PARSE_TOKEN(strval);
605 model.set_domain(strval);
606 break;
607 case KeyWordMap::KeyWord::MODEL_VERSION:
608 PARSE_TOKEN(intval);
609 model.set_model_version(intval);
610 break;
611 case KeyWordMap::KeyWord::DOC_STRING:
612 PARSE_TOKEN(strval);
613 model.set_doc_string(strval);
614 break;
615 case KeyWordMap::KeyWord::METADATA_PROPS: {
616 auto& metadata_props = *model.mutable_metadata_props();
617 MATCH('[');
618 if (!Matches(']')) {
619 do {
620 auto* metadata = metadata_props.Add();
621 PARSE_TOKEN(strval);
622 metadata->set_key(strval);
623 MATCH(':');
624 PARSE_TOKEN(strval);
625 metadata->set_value(strval);
626 } while (Matches(','));
627 MATCH(']');
628 }
629 break;
630 }
631 default:
632 return ParseError("Unhandled keyword.");
633 }
634 } while (Matches(','));
635 MATCH('>');
636 }
637 PARSE(*model.mutable_graph());
638
639 auto* functions = model.mutable_functions();
640 while (!EndOfInput()) {
641 PARSE(*functions->Add());
642 }
643 return Status::OK();
644}
645
646} // namespace ONNX_NAMESPACE
647