1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
6// by this parser is preliminary and may change.
7
8#pragma once
9
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <ctype.h>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include "onnx/onnx_pb.h"

#include "onnx/common/status.h"
#include "onnx/string_utils.h"
20
21namespace ONNX_NAMESPACE {
22
23using namespace ONNX_NAMESPACE::Common;
24
25using IdList = google::protobuf::RepeatedPtrField<std::string>;
26
27using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;
28
29using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;
30
31using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;
32
33using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;
34
35using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;
36
// Evaluates `status` (an expression yielding a Status) exactly once and, if it
// is not OK, returns it from the enclosing function.
// The do/while(0) wrapper makes the expansion a single statement that requires
// a trailing semicolon, so it is safe inside an unbraced if/else.
#define CHECK_PARSER_STATUS(status) \
  do {                              \
    auto local_status_ = (status);  \
    if (!local_status_.IsOK())      \
      return local_status_;         \
  } while (0)
43
/// CRTP base for a lazily built, process-wide string -> int32 table.
/// `Map` must derive from StringIntMap<Map> and fill `map_` in its default
/// constructor; one `Map` object is created on the first call to Instance().
template <typename Map>
class StringIntMap {
 public:
  /// Returns the table held by the function-local `Map` singleton.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static Map singleton;
    return singleton.map_;
  }

  /// Returns the code registered for `dtype`, or 0 when `dtype` is unknown.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return (found == table.end()) ? 0 : found->second;
  }

  /// Returns a name registered with code `dtype` (found by linear scan of the
  /// table), or "undefined" when no name has that code.
  static const std::string& ToString(int32_t dtype) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype)
        return entry->first;
    }
    return undefined;
  }

 protected:
  // Filled by the derived class's constructor.
  std::unordered_map<std::string, int32_t> map_;
};
71
72class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
73 public:
74 PrimitiveTypeNameMap() : StringIntMap() {
75 map_["float"] = 1;
76 map_["uint8"] = 2;
77 map_["int8"] = 3;
78 map_["uint16"] = 4;
79 map_["int16"] = 5;
80 map_["int32"] = 6;
81 map_["int64"] = 7;
82 map_["string"] = 8;
83 map_["bool"] = 9;
84 map_["float16"] = 10;
85 map_["double"] = 11;
86 map_["uint32"] = 12;
87 map_["uint64"] = 13;
88 map_["complex64"] = 14;
89 map_["complex128"] = 15;
90 map_["bfloat16"] = 16;
91 }
92
93 static bool IsTypeName(const std::string& dtype) {
94 return Lookup(dtype) != 0;
95 }
96};
97
98class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
99 public:
100 AttributeTypeNameMap() : StringIntMap() {
101 map_["float"] = 1;
102 map_["int"] = 2;
103 map_["string"] = 3;
104 map_["tensor"] = 4;
105 map_["graph"] = 5;
106 map_["sparse_tensor"] = 11;
107 map_["type_proto"] = 13;
108 map_["floats"] = 6;
109 map_["ints"] = 7;
110 map_["strings"] = 8;
111 map_["tensors"] = 9;
112 map_["graphs"] = 10;
113 map_["sparse_tensors"] = 12;
114 map_["type_protos"] = 14;
115 }
116};
117
/// Table of the reserved words of the textual syntax, mapping each spelling to
/// a KeyWord enumerator. One table is built lazily on the first Instance() call.
class KeyWordMap {
 public:
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS,
    SEQ_TYPE,
    MAP_TYPE,
    OPTIONAL_TYPE,
    SPARSE_TENSOR_TYPE
  };

  /// Fills the spelling -> KeyWord table (KeyWord::NONE has no spelling).
  KeyWordMap() {
    struct Entry {
      const char* spelling;
      KeyWord keyword;
    };
    const Entry entries[] = {
        {"ir_version", KeyWord::IR_VERSION},
        {"opset_import", KeyWord::OPSET_IMPORT},
        {"producer_name", KeyWord::PRODUCER_NAME},
        {"producer_version", KeyWord::PRODUCER_VERSION},
        {"domain", KeyWord::DOMAIN_KW},
        {"model_version", KeyWord::MODEL_VERSION},
        {"doc_string", KeyWord::DOC_STRING},
        {"metadata_props", KeyWord::METADATA_PROPS},
        {"seq", KeyWord::SEQ_TYPE},
        {"map", KeyWord::MAP_TYPE},
        {"optional", KeyWord::OPTIONAL_TYPE},
        {"sparse_tensor", KeyWord::SPARSE_TENSOR_TYPE},
    };
    for (const Entry& entry : entries)
      map_[entry.spelling] = entry.keyword;
  }

  /// Returns the table held by the function-local singleton.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static KeyWordMap singleton;
    return singleton.map_;
  }

  /// Returns the keyword spelled `id`, or KeyWord::NONE if `id` is not reserved.
  static KeyWord Lookup(const std::string& id) {
    const auto& table = Instance();
    const auto hit = table.find(id);
    return (hit == table.end()) ? KeyWord::NONE : hit->second;
  }

  /// Returns the spelling of `kw` (linear scan), or "undefined" if it has none.
  static const std::string& ToString(KeyWord kw) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == kw)
        return entry->first;
    }
    return undefined;
  }

 private:
  std::unordered_map<std::string, KeyWord> map_;
};
175
176class ParserBase {
177 public:
178 ParserBase(const std::string& str)
179 : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
180
181 ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
182
183 void SavePos() {
184 saved_pos_ = next_;
185 }
186
187 void RestorePos() {
188 next_ = saved_pos_;
189 }
190
191 std::string GetCurrentPos() {
192 uint32_t line = 1, col = 1;
193 for (const char* p = start_; p < next_; ++p) {
194 if (*p == '\n') {
195 ++line;
196 col = 1;
197 } else {
198 ++col;
199 }
200 }
201 return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
202 }
203
204 // Return a suitable suffix of what has been parsed to provide error message context:
205 // return the line containing the last non-space character preceding the error (if it exists).
206 std::string GetErrorContext() {
207 // Special cases: empty input string, and parse-error at first character.
208 const char* p = next_ < end_ ? next_ : next_ - 1;
209 while ((p > start_) && isspace(*p))
210 --p;
211 while ((p > start_) && (*p != '\n'))
212 --p;
213 // Start at character after '\n' unless we are at start of input
214 const char* context_start = (p > start_) ? (p + 1) : start_;
215 for (p = context_start; (p < end_) && (*p != '\n'); ++p)
216 ;
217 return std::string(context_start, p - context_start);
218 }
219
220 template <typename... Args>
221 Status ParseError(const Args&... args) {
222 return Status(
223 NONE,
224 FAIL,
225 ONNX_NAMESPACE::MakeString(
226 "[ParseError at position ", GetCurrentPos(), "]\n", "Error context: ", GetErrorContext(), "\n", args...));
227 }
228
229 void SkipWhiteSpace() {
230 do {
231 while ((next_ < end_) && (isspace(*next_)))
232 ++next_;
233 if ((next_ >= end_) || ((*next_) != '#'))
234 return;
235 // Skip rest of the line:
236 while ((next_ < end_) && ((*next_) != '\n'))
237 ++next_;
238 } while (true);
239 }
240
241 int NextChar(bool skipspace = true) {
242 if (skipspace)
243 SkipWhiteSpace();
244 return (next_ < end_) ? *next_ : 0;
245 }
246
247 bool Matches(char ch, bool skipspace = true) {
248 if (skipspace)
249 SkipWhiteSpace();
250 if ((next_ < end_) && (*next_ == ch)) {
251 ++next_;
252 return true;
253 }
254 return false;
255 }
256
257 Status Match(char ch, bool skipspace = true) {
258 if (!Matches(ch, skipspace))
259 return ParseError("Expected character ", ch, " not found.");
260 return Status::OK();
261 }
262
263 bool EndOfInput() {
264 SkipWhiteSpace();
265 return (next_ >= end_);
266 }
267
268 enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
269
270 struct Literal {
271 LiteralType type;
272 std::string value;
273 };
274
275 Status Parse(Literal& result);
276
277 Status Parse(int64_t& val) {
278 Literal literal;
279 CHECK_PARSER_STATUS(Parse(literal));
280 if (literal.type != LiteralType::INT_LITERAL)
281 return ParseError("Integer value expected, but not found.");
282 std::string s = literal.value;
283 val = std::stoll(s);
284 return Status::OK();
285 }
286
287 Status Parse(uint64_t& val) {
288 Literal literal;
289 CHECK_PARSER_STATUS(Parse(literal));
290 if (literal.type != LiteralType::INT_LITERAL)
291 return ParseError("Integer value expected, but not found.");
292 std::string s = literal.value;
293 val = std::stoull(s);
294 return Status::OK();
295 }
296
297 Status Parse(float& val) {
298 Literal literal;
299 CHECK_PARSER_STATUS(Parse(literal));
300 switch (literal.type) {
301 case LiteralType::INT_LITERAL:
302 case LiteralType::FLOAT_LITERAL:
303 val = std::stof(literal.value);
304 break;
305 default:
306 return ParseError("Unexpected literal type.");
307 }
308 return Status::OK();
309 }
310
311 Status Parse(double& val) {
312 Literal literal;
313 CHECK_PARSER_STATUS(Parse(literal));
314 switch (literal.type) {
315 case LiteralType::INT_LITERAL:
316 case LiteralType::FLOAT_LITERAL:
317 val = std::stod(literal.value);
318 break;
319 default:
320 return ParseError("Unexpected literal type.");
321 }
322 return Status::OK();
323 }
324
325 // Parse a string-literal enclosed within doube-quotes.
326 Status Parse(std::string& val) {
327 Literal literal;
328 CHECK_PARSER_STATUS(Parse(literal));
329 if (literal.type != LiteralType::STRING_LITERAL)
330 return ParseError("String value expected, but not found.");
331 val = literal.value;
332 return Status::OK();
333 }
334
335 // Parse an identifier, including keywords. If none found, this will
336 // return an empty-string identifier.
337 Status ParseOptionalIdentifier(std::string& id) {
338 SkipWhiteSpace();
339 auto from = next_;
340 if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
341 ++next_;
342 while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
343 ++next_;
344 }
345 id = std::string(from, next_ - from);
346 return Status::OK();
347 }
348
349 Status ParseIdentifier(std::string& id) {
350 ParseOptionalIdentifier(id);
351 if (id.empty())
352 return ParseError("Identifier expected but not found.");
353 return Status::OK();
354 }
355
356 Status PeekIdentifier(std::string& id) {
357 SavePos();
358 ParseOptionalIdentifier(id);
359 RestorePos();
360 return Status::OK();
361 }
362
363 Status Parse(KeyWordMap::KeyWord& keyword) {
364 std::string id;
365 CHECK_PARSER_STATUS(ParseIdentifier(id));
366 keyword = KeyWordMap::Lookup(id);
367 return Status::OK();
368 }
369
370 protected:
371 const char* start_;
372 const char* next_;
373 const char* end_;
374 const char* saved_pos_;
375};
376
377class OnnxParser : public ParserBase {
378 public:
379 OnnxParser(const char* cstr) : ParserBase(cstr) {}
380
381 Status Parse(TensorShapeProto& shape);
382
383 Status Parse(TypeProto& typeProto);
384
385 Status Parse(TensorProto& tensorProto);
386
387 Status Parse(AttributeProto& attr);
388
389 Status Parse(AttrList& attrlist);
390
391 Status Parse(NodeProto& node);
392
393 Status Parse(NodeList& nodelist);
394
395 Status Parse(GraphProto& graph);
396
397 Status Parse(FunctionProto& fn);
398
399 Status Parse(ModelProto& model);
400
401 template <typename T>
402 static Status Parse(T& parsedData, const char* input) {
403 OnnxParser parser(input);
404 return parser.Parse(parsedData);
405 }
406
407 private:
408 Status Parse(std::string name, GraphProto& graph);
409
410 Status Parse(IdList& idlist);
411
412 Status Parse(char open, IdList& idlist, char close);
413
414 Status ParseSingleAttributeValue(AttributeProto& attr);
415
416 Status Parse(ValueInfoProto& valueinfo);
417
418 Status Parse(ValueInfoList& vilist);
419
420 Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
421
422 Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
423
424 Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
425
426 Status Parse(OpsetIdList& opsets);
427
428 bool NextIsType();
429};
430
431} // namespace ONNX_NAMESPACE