1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // Experimental language syntax and parser for ONNX. Please note that the syntax as formalized |
6 | // by this parser is preliminary and may change. |
7 | |
8 | #include <stdexcept> |
9 | #include <string> |
10 | #include <unordered_map> |
11 | |
12 | #include "onnx/onnx_pb.h" |
13 | #include "onnx/string_utils.h" |
14 | |
15 | #include "onnx/defs/parser.h" |
16 | |
// Convenience wrappers used throughout this file. Each evaluates a call that yields a Status and
// hands it to CHECK_PARSER_STATUS (declared in the parser header; presumably it returns that Status
// from the enclosing function when it is not OK -- confirm against onnx/defs/parser.h).
// PARSE_TOKEN: token-level values (literals, numbers, strings, keywords) via ParserBase::Parse,
// with the overload selected by the argument's type.
#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
// PARSE: whichever Parse overload is visible in the calling scope (typically OnnxParser::Parse).
#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
// MATCH: Match(...) with all arguments forwarded unchanged (e.g. MATCH('>', false)).
#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))
20 | |
21 | namespace ONNX_NAMESPACE { |
22 | |
23 | Status ParserBase::Parse(Literal& result) { |
24 | bool decimal_point = false; |
25 | auto nextch = NextChar(); |
26 | auto from = next_; |
27 | if (nextch == '"') { |
28 | ++next_; |
29 | bool has_escape = false; |
30 | while ((next_ < end_) && (*next_ != '"')) { |
31 | if (*next_ == '\\') { |
32 | has_escape = true; |
33 | ++next_; |
34 | if (next_ >= end_) |
35 | return ParseError("Incomplete string literal." ); |
36 | } |
37 | ++next_; |
38 | } |
39 | if (next_ >= end_) |
40 | return ParseError("Incomplete string literal." ); |
41 | ++next_; |
42 | result.type = LiteralType::STRING_LITERAL; |
43 | if (has_escape) { |
44 | std::string& target = result.value; |
45 | target.clear(); |
46 | target.reserve(next_ - from - 2); // upper bound |
47 | // *from is the starting quote. *(next_-1) is the ending quote. |
48 | // Copy what is in-between, except for the escape character |
49 | while (++from < next_ - 1) { |
50 | // Copy current char, if not escape, or next char otherwise. |
51 | target.push_back(*from != '\\' ? (*from) : *(++from)); |
52 | } |
53 | } else |
54 | result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes |
55 | } else if ((isdigit(nextch) || (nextch == '-'))) { |
56 | ++next_; |
57 | |
58 | while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) { |
59 | if (*next_ == '.') { |
60 | if (decimal_point) |
61 | break; // Only one decimal point allowed in numeric literal |
62 | decimal_point = true; |
63 | } |
64 | ++next_; |
65 | } |
66 | |
67 | if (next_ == from) |
68 | return ParseError("Value expected but not found." ); |
69 | |
70 | // Optional exponent syntax: (e|E)(+|-)?[0-9]+ |
71 | if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) { |
72 | decimal_point = true; // treat as float-literal |
73 | ++next_; |
74 | if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-'))) |
75 | ++next_; |
76 | while ((next_ < end_) && (isdigit(*next_))) |
77 | ++next_; |
78 | } |
79 | |
80 | result.value = std::string(from, next_ - from); |
81 | result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL; |
82 | } |
83 | return Status::OK(); |
84 | } |
85 | |
86 | Status OnnxParser::Parse(IdList& idlist) { |
87 | idlist.Clear(); |
88 | std::string id; |
89 | ParseOptionalIdentifier(id); |
90 | if (id.empty()) |
91 | return Status::OK(); // Treat as empty list of identifiers |
92 | *idlist.Add() = id; |
93 | while (Matches(',')) { |
94 | ParseOptionalIdentifier(id); |
95 | *idlist.Add() = id; |
96 | } |
97 | return Status::OK(); |
98 | } |
99 | |
100 | Status OnnxParser::Parse(char open, IdList& idlist, char close) { |
101 | idlist.Clear(); |
102 | if (Matches(open)) { |
103 | PARSE(idlist); |
104 | MATCH(close); |
105 | } |
106 | return Status::OK(); |
107 | } |
108 | |
109 | Status OnnxParser::Parse(TensorShapeProto& shape) { |
110 | shape.clear_dim(); |
111 | do { |
112 | if (Matches('?')) { |
113 | shape.add_dim(); |
114 | } else { |
115 | // Check for a symbolic identifier ... |
116 | std::string id; |
117 | CHECK_PARSER_STATUS(ParseOptionalIdentifier(id)); |
118 | if (!id.empty()) { |
119 | shape.add_dim()->set_dim_param(id); |
120 | } else { |
121 | // ...or a integer value |
122 | int64_t dimval = 0; |
123 | PARSE_TOKEN(dimval); |
124 | shape.add_dim()->set_dim_value(dimval); |
125 | } |
126 | } |
127 | } while (Matches(',')); |
128 | return Status::OK(); |
129 | } |
130 | |
131 | Status OnnxParser::Parse(TypeProto& typeProto) { |
132 | std::string id; |
133 | CHECK_PARSER_STATUS(ParseIdentifier(id)); |
134 | int dtype = PrimitiveTypeNameMap::Lookup(id); |
135 | if (dtype != 0) { |
136 | auto* tensortype = typeProto.mutable_tensor_type(); |
137 | tensortype->set_elem_type(dtype); |
138 | tensortype->clear_shape(); |
139 | // Grammar: |
140 | // float indicates scalar (rank 0) |
141 | // float [] indicates unknown rank tensor (not a zero rank tensor) |
142 | // float [one-or-more-dimensions] indicates tensor of known rank > 0. |
143 | if (Matches('[')) { |
144 | if (!Matches(']')) { |
145 | PARSE(*tensortype->mutable_shape()); |
146 | MATCH(']'); |
147 | } |
148 | } else { |
149 | // Create shape with zero dimensions for scalar |
150 | (void)(tensortype->mutable_shape()); |
151 | } |
152 | } else { |
153 | switch (KeyWordMap::Lookup(id)) { |
154 | case KeyWordMap::KeyWord::SEQ_TYPE: { |
155 | // Grammar: seq ( type ) |
156 | MATCH('('); |
157 | auto* seqtype = typeProto.mutable_sequence_type(); |
158 | PARSE(*seqtype->mutable_elem_type()); |
159 | MATCH(')'); |
160 | break; |
161 | } |
162 | case KeyWordMap::KeyWord::MAP_TYPE: { |
163 | // Grammar: map ( prim-type , type ) |
164 | MATCH('('); |
165 | auto* maptype = typeProto.mutable_map_type(); |
166 | CHECK_PARSER_STATUS(ParseIdentifier(id)); |
167 | dtype = PrimitiveTypeNameMap::Lookup(id); |
168 | if (dtype == 0) { |
169 | return ParseError("Expecting primitive type as map key type." ); |
170 | } |
171 | maptype->set_key_type(dtype); |
172 | MATCH(','); |
173 | PARSE(*maptype->mutable_value_type()); |
174 | MATCH(')'); |
175 | break; |
176 | } |
177 | case KeyWordMap::KeyWord::OPTIONAL_TYPE: { |
178 | // Grammar: optional ( type ) |
179 | MATCH('('); |
180 | auto* opttype = typeProto.mutable_optional_type(); |
181 | PARSE(*opttype->mutable_elem_type()); |
182 | MATCH(')'); |
183 | break; |
184 | } |
185 | case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: { |
186 | // Grammar: sparse_tensor ( tensor-type ) |
187 | MATCH('('); |
188 | CHECK_PARSER_STATUS(ParseIdentifier(id)); |
189 | dtype = PrimitiveTypeNameMap::Lookup(id); |
190 | if (dtype != 0) { |
191 | auto* sparsetype = typeProto.mutable_sparse_tensor_type(); |
192 | sparsetype->set_elem_type(dtype); |
193 | sparsetype->clear_shape(); |
194 | // Grammar: |
195 | // float indicates scalar (rank 0) |
196 | // float [] indicates unknown rank tensor (not a zero rank tensor) |
197 | // float [one-or-more-dimensions] indicates tensor of known rank > 0. |
198 | if (Matches('[')) { |
199 | if (!Matches(']')) { |
200 | PARSE(*sparsetype->mutable_shape()); |
201 | MATCH(']'); |
202 | } |
203 | } else { |
204 | // Create shape with zero dimensions for scalar |
205 | (void)(sparsetype->mutable_shape()); |
206 | } |
207 | } else { |
208 | return ParseError("Unexpected type in sparse-tensor element type." ); |
209 | } |
210 | MATCH(')'); |
211 | break; |
212 | } |
213 | default: |
214 | return ParseError("Unexpected type." ); |
215 | } |
216 | } |
217 | return Status::OK(); |
218 | } |
219 | |
220 | Status OnnxParser::Parse(ValueInfoProto& valueinfo) { |
221 | if (NextIsType()) |
222 | PARSE(*valueinfo.mutable_type()); |
223 | std::string name; |
224 | CHECK_PARSER_STATUS(ParseIdentifier(name)); |
225 | valueinfo.set_name(name); |
226 | return Status::OK(); |
227 | } |
228 | |
229 | Status OnnxParser::Parse(ValueInfoList& vilist) { |
230 | vilist.Clear(); |
231 | MATCH('('); |
232 | if (!Matches(')')) { |
233 | do { |
234 | PARSE(*vilist.Add()); |
235 | } while (Matches(',')); |
236 | MATCH(')'); |
237 | } |
238 | return Status::OK(); |
239 | } |
240 | |
241 | // Each input element is a value-info with an optional initializer of the form "= initial-value". |
242 | // The value-info is added to the "inputs", while the initializer is added to initializers. |
243 | Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) { |
244 | inputs.Clear(); |
245 | if (Matches('(')) { |
246 | if (!Matches(')')) { |
247 | do { |
248 | ValueInfoProto vi; |
249 | PARSE(vi); |
250 | *inputs.Add() = vi; |
251 | if (Matches('=')) { |
252 | // default value for input |
253 | TensorProto& tp = *initializers.Add(); |
254 | tp.set_name(vi.name()); |
255 | CHECK_PARSER_STATUS(Parse(tp, vi.type())); |
256 | } |
257 | } while (Matches(',')); |
258 | MATCH(')'); |
259 | } |
260 | } |
261 | return Status::OK(); |
262 | } |
263 | |
264 | // This is handled slightly different from the inputs. |
265 | // Each element is either a value-info or an initializer. |
266 | // A value-info is added to the "value_infos", while an initializer is added to initializers. |
267 | Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) { |
268 | value_infos.Clear(); |
269 | if (Matches('<')) { |
270 | if (!Matches('>')) { |
271 | do { |
272 | ValueInfoProto vi; |
273 | PARSE(vi); |
274 | if (Matches('=')) { |
275 | // initializer |
276 | TensorProto& tp = *initializers.Add(); |
277 | tp.set_name(vi.name()); |
278 | CHECK_PARSER_STATUS(Parse(tp, vi.type())); |
279 | } else { |
280 | // valueinfo |
281 | *value_infos.Add() = vi; |
282 | } |
283 | } while (Matches(',')); |
284 | MATCH('>'); |
285 | } |
286 | } |
287 | return Status::OK(); |
288 | } |
289 | |
290 | Status OnnxParser::Parse(TensorProto& tensorProto) { |
291 | tensorProto = TensorProto(); |
292 | // Parse the concrete tensor-type with numeric dimensions: |
293 | TypeProto typeProto; |
294 | PARSE(typeProto); |
295 | ParseOptionalIdentifier(*tensorProto.mutable_name()); |
296 | (void)Matches('='); // Optional, to unify handling of initializers as well as tensor-protos in other contexts |
297 | return Parse(tensorProto, typeProto); |
298 | } |
299 | |
300 | // Parse TensorProto data given its type: |
301 | Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) { |
302 | if (!tensorTypeProto.has_tensor_type()) |
303 | return ParseError("Error parsing TensorProto (expected a tensor type)." ); |
304 | auto elem_type = tensorTypeProto.tensor_type().elem_type(); |
305 | tensorProto.set_data_type(elem_type); |
306 | if (!tensorTypeProto.tensor_type().has_shape()) |
307 | return ParseError("Error parsing TensorProto (expected a tensor shape)." ); |
308 | uint64_t n = 1; |
309 | for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) { |
310 | if (!dim.has_dim_value()) |
311 | return ParseError("Error parsing TensorProto shape (expected numeric dimension)." ); |
312 | auto dimval = dim.dim_value(); |
313 | tensorProto.add_dims(dimval); |
314 | n *= dimval; |
315 | } |
316 | |
317 | // tensorProto.mutable_int64_data()->Reserve(n); |
318 | // Parse the actual values: |
319 | |
320 | int64_t intval; |
321 | uint64_t uintval; |
322 | float floatval; |
323 | double dblval; |
324 | std::string strval; |
325 | MATCH('{'); |
326 | if (!Matches('}')) { |
327 | do { |
328 | switch (static_cast<TensorProto::DataType>(elem_type)) { |
329 | case TensorProto::DataType::TensorProto_DataType_INT8: |
330 | case TensorProto::DataType::TensorProto_DataType_INT16: |
331 | case TensorProto::DataType::TensorProto_DataType_INT32: |
332 | case TensorProto::DataType::TensorProto_DataType_UINT8: |
333 | case TensorProto::DataType::TensorProto_DataType_UINT16: |
334 | case TensorProto::DataType::TensorProto_DataType_BOOL: |
335 | PARSE_TOKEN(intval); |
336 | // TODO: check values are in the correct range. |
337 | tensorProto.add_int32_data(intval); |
338 | break; |
339 | case TensorProto::DataType::TensorProto_DataType_INT64: |
340 | PARSE_TOKEN(intval); |
341 | tensorProto.add_int64_data(intval); |
342 | break; |
343 | case TensorProto::DataType::TensorProto_DataType_UINT32: |
344 | case TensorProto::DataType::TensorProto_DataType_UINT64: |
345 | PARSE_TOKEN(uintval); |
346 | tensorProto.add_uint64_data(uintval); |
347 | break; |
348 | case TensorProto::DataType::TensorProto_DataType_FLOAT: |
349 | PARSE_TOKEN(floatval); |
350 | tensorProto.add_float_data(floatval); |
351 | break; |
352 | case TensorProto::DataType::TensorProto_DataType_DOUBLE: |
353 | PARSE_TOKEN(dblval); |
354 | tensorProto.add_double_data(dblval); |
355 | break; |
356 | case TensorProto::DataType::TensorProto_DataType_STRING: |
357 | PARSE_TOKEN(strval); |
358 | tensorProto.add_string_data(strval); |
359 | break; |
360 | default: |
361 | return ParseError("Unhandled type: %d" , elem_type); |
362 | } |
363 | } while (Matches(',')); |
364 | MATCH('}'); |
365 | } |
366 | return Status::OK(); |
367 | } |
368 | |
369 | bool OnnxParser::NextIsType() { |
370 | std::string id("" ); |
371 | (void)PeekIdentifier(id); |
372 | return (PrimitiveTypeNameMap::IsTypeName(id)); |
373 | } |
374 | |
375 | Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) { |
376 | // Parse a single-value |
377 | auto next = NextChar(); |
378 | if (isalpha(next) || next == '_') { |
379 | if (NextIsType()) { |
380 | attr.set_type(AttributeProto_AttributeType_TENSOR); |
381 | Parse(*attr.mutable_t()); |
382 | } else { |
383 | attr.set_type(AttributeProto_AttributeType_GRAPH); |
384 | Parse(*attr.mutable_g()); |
385 | } |
386 | } else if (Matches('@')) { |
387 | std::string name; |
388 | CHECK_PARSER_STATUS(ParseIdentifier(name)); |
389 | attr.set_ref_attr_name(name); |
390 | } else { |
391 | Literal literal; |
392 | PARSE_TOKEN(literal); |
393 | switch (literal.type) { |
394 | case LiteralType::INT_LITERAL: |
395 | attr.set_type(AttributeProto_AttributeType_INT); |
396 | attr.set_i(std::stol(literal.value)); |
397 | break; |
398 | case LiteralType::FLOAT_LITERAL: |
399 | attr.set_type(AttributeProto_AttributeType_FLOAT); |
400 | attr.set_f(static_cast<float>(std::stof(literal.value))); |
401 | break; |
402 | case LiteralType::STRING_LITERAL: |
403 | attr.set_type(AttributeProto_AttributeType_STRING); |
404 | attr.set_s(literal.value); |
405 | break; |
406 | default: |
407 | return ParseError("Unexpected literal type." ); |
408 | } |
409 | } |
410 | return Status::OK(); |
411 | } |
412 | |
413 | Status OnnxParser::Parse(AttributeProto& attr) { |
414 | attr.Clear(); |
415 | std::string name; |
416 | CHECK_PARSER_STATUS(ParseIdentifier(name)); |
417 | attr.set_name(name); |
418 | if (Matches(':')) { |
419 | CHECK_PARSER_STATUS(ParseIdentifier(name)); |
420 | int attrtype = AttributeTypeNameMap::Lookup(name); |
421 | if (attrtype != 0) { |
422 | attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype)); |
423 | } else { |
424 | return ParseError("Unexpected attribute type." ); |
425 | } |
426 | } |
427 | MATCH('='); |
428 | if (NextChar() == '[') { |
429 | // Parse a list of values. For now, empty list is not allowed, as we need to |
430 | // figure out a type for the attribute. |
431 | std::vector<Literal> vals; |
432 | MATCH('['); |
433 | do { |
434 | AttributeProto nextval; |
435 | CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval)); |
436 | switch (nextval.type()) { |
437 | case AttributeProto_AttributeType_INT: |
438 | attr.set_type(AttributeProto_AttributeType_INTS); |
439 | attr.add_ints(nextval.i()); |
440 | break; |
441 | case AttributeProto_AttributeType_FLOAT: |
442 | attr.set_type(AttributeProto_AttributeType_FLOATS); |
443 | attr.add_floats(nextval.f()); |
444 | break; |
445 | case AttributeProto_AttributeType_STRING: |
446 | attr.add_strings(nextval.s()); |
447 | attr.set_type(AttributeProto_AttributeType_STRINGS); |
448 | break; |
449 | default: |
450 | break; |
451 | } |
452 | } while (Matches(',')); |
453 | MATCH(']'); |
454 | } else { |
455 | CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr)); |
456 | } |
457 | return Status::OK(); |
458 | } |
459 | |
460 | Status OnnxParser::Parse(AttrList& attrlist) { |
461 | attrlist.Clear(); |
462 | if (Matches('<')) { |
463 | do { |
464 | PARSE(*attrlist.Add()); |
465 | } while (Matches(',')); |
466 | MATCH('>'); |
467 | } |
468 | return Status::OK(); |
469 | } |
470 | |
471 | Status OnnxParser::Parse(NodeProto& node) { |
472 | PARSE(*node.mutable_output()); |
473 | MATCH('='); |
474 | std::string domain("" ); |
475 | std::string id; |
476 | ParseIdentifier(id); |
477 | while (Matches('.')) { |
478 | if (!domain.empty()) |
479 | domain += "." ; |
480 | domain += id; |
481 | ParseIdentifier(id); |
482 | } |
483 | node.set_domain(domain); |
484 | node.set_op_type(id); |
485 | PARSE(*node.mutable_attribute()); |
486 | MATCH('('); |
487 | PARSE(*node.mutable_input()); |
488 | MATCH(')'); |
489 | if (node.attribute_size() == 0) { |
490 | // Permit attributes to be specified before or after parameters. |
491 | PARSE(*node.mutable_attribute()); |
492 | } |
493 | return Status::OK(); |
494 | } |
495 | |
496 | Status OnnxParser::Parse(NodeList& nodelist) { |
497 | nodelist.Clear(); |
498 | MATCH('{'); |
499 | while (!Matches('}')) { |
500 | PARSE(*nodelist.Add()); |
501 | } |
502 | return Status::OK(); |
503 | } |
504 | |
505 | Status OnnxParser::Parse(GraphProto& graph) { |
506 | std::string id; |
507 | ParseIdentifier(id); |
508 | return Parse(id, graph); |
509 | } |
510 | |
511 | Status OnnxParser::Parse(std::string name, GraphProto& graph) { |
512 | graph.set_name(name); |
513 | graph.mutable_initializer()->Clear(); |
514 | CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *graph.mutable_initializer())); |
515 | MATCH('='); |
516 | MATCH('>', false); |
517 | PARSE(*graph.mutable_output()); |
518 | CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer())); |
519 | return Parse(*graph.mutable_node()); |
520 | } |
521 | |
522 | Status OnnxParser::Parse(FunctionProto& fn) { |
523 | fn.Clear(); |
524 | std::string strval; |
525 | if (Matches('<')) { |
526 | do { |
527 | KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE; |
528 | PARSE_TOKEN(keyword); |
529 | MATCH(':'); |
530 | switch (keyword) { |
531 | case KeyWordMap::KeyWord::OPSET_IMPORT: |
532 | PARSE(*fn.mutable_opset_import()); |
533 | break; |
534 | case KeyWordMap::KeyWord::DOC_STRING: |
535 | PARSE_TOKEN(strval); |
536 | fn.set_doc_string(strval); |
537 | break; |
538 | case KeyWordMap::KeyWord::DOMAIN_KW: |
539 | PARSE_TOKEN(strval); |
540 | fn.set_domain(strval); |
541 | break; |
542 | default: |
543 | return ParseError("Unhandled keyword." ); |
544 | } |
545 | } while (Matches(',')); |
546 | MATCH('>'); |
547 | } |
548 | std::string id; |
549 | ParseIdentifier(id); |
550 | fn.set_name(id); |
551 | |
552 | PARSE('<', *fn.mutable_attribute(), '>'); |
553 | PARSE('(', *fn.mutable_input(), ')'); |
554 | MATCH('='); |
555 | MATCH('>', false); |
556 | PARSE('(', *fn.mutable_output(), ')'); |
557 | return Parse(*fn.mutable_node()); |
558 | } |
559 | |
560 | Status OnnxParser::Parse(OpsetIdList& opsets) { |
561 | std::string strval; |
562 | int64_t intval = 0; |
563 | MATCH('['); |
564 | if (!Matches(']')) { |
565 | do { |
566 | auto* import = opsets.Add(); |
567 | PARSE_TOKEN(strval); |
568 | import->set_domain(strval); |
569 | MATCH(':'); |
570 | PARSE_TOKEN(intval); |
571 | import->set_version(intval); |
572 | } while (Matches(',')); |
573 | MATCH(']'); |
574 | } |
575 | return Status::OK(); |
576 | } |
577 | |
578 | Status OnnxParser::Parse(ModelProto& model) { |
579 | model.Clear(); |
580 | std::string strval; |
581 | int64_t intval; |
582 | if (Matches('<')) { |
583 | do { |
584 | KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE; |
585 | PARSE_TOKEN(keyword); |
586 | MATCH(':'); |
587 | switch (keyword) { |
588 | case KeyWordMap::KeyWord::IR_VERSION: |
589 | PARSE_TOKEN(intval); |
590 | model.set_ir_version(intval); |
591 | break; |
592 | case KeyWordMap::KeyWord::OPSET_IMPORT: |
593 | PARSE(*model.mutable_opset_import()); |
594 | break; |
595 | case KeyWordMap::KeyWord::PRODUCER_NAME: |
596 | PARSE_TOKEN(strval); |
597 | model.set_producer_name(strval); |
598 | break; |
599 | case KeyWordMap::KeyWord::PRODUCER_VERSION: |
600 | PARSE_TOKEN(strval); |
601 | model.set_producer_version(strval); |
602 | break; |
603 | case KeyWordMap::KeyWord::DOMAIN_KW: |
604 | PARSE_TOKEN(strval); |
605 | model.set_domain(strval); |
606 | break; |
607 | case KeyWordMap::KeyWord::MODEL_VERSION: |
608 | PARSE_TOKEN(intval); |
609 | model.set_model_version(intval); |
610 | break; |
611 | case KeyWordMap::KeyWord::DOC_STRING: |
612 | PARSE_TOKEN(strval); |
613 | model.set_doc_string(strval); |
614 | break; |
615 | case KeyWordMap::KeyWord::METADATA_PROPS: { |
616 | auto& metadata_props = *model.mutable_metadata_props(); |
617 | MATCH('['); |
618 | if (!Matches(']')) { |
619 | do { |
620 | auto* metadata = metadata_props.Add(); |
621 | PARSE_TOKEN(strval); |
622 | metadata->set_key(strval); |
623 | MATCH(':'); |
624 | PARSE_TOKEN(strval); |
625 | metadata->set_value(strval); |
626 | } while (Matches(',')); |
627 | MATCH(']'); |
628 | } |
629 | break; |
630 | } |
631 | default: |
632 | return ParseError("Unhandled keyword." ); |
633 | } |
634 | } while (Matches(',')); |
635 | MATCH('>'); |
636 | } |
637 | PARSE(*model.mutable_graph()); |
638 | |
639 | auto* functions = model.mutable_functions(); |
640 | while (!EndOfInput()) { |
641 | PARSE(*functions->Add()); |
642 | } |
643 | return Status::OK(); |
644 | } |
645 | |
646 | } // namespace ONNX_NAMESPACE |
647 | |