1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // Experimental language syntax and parser for ONNX. Please note that the syntax as formalized |
6 | // by this parser is preliminary and may change. |
7 | |
8 | #pragma once |
9 | |
#include <ctype.h>

#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include "onnx/onnx_pb.h"

#include "onnx/common/status.h"
#include "onnx/string_utils.h"
20 | |
21 | namespace ONNX_NAMESPACE { |
22 | |
23 | using namespace ONNX_NAMESPACE::Common; |
24 | |
25 | using IdList = google::protobuf::RepeatedPtrField<std::string>; |
26 | |
27 | using NodeList = google::protobuf::RepeatedPtrField<NodeProto>; |
28 | |
29 | using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>; |
30 | |
31 | using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>; |
32 | |
33 | using TensorList = google::protobuf::RepeatedPtrField<TensorProto>; |
34 | |
35 | using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>; |
36 | |
// Evaluates `status` exactly once and stores the result in a local variable.
// If that result reports !IsOK(), it is returned from the enclosing function;
// otherwise execution continues after the macro. The expansion is a plain
// brace block, so it can only be used where a statement is allowed.
#define CHECK_PARSER_STATUS(status) \
  {                                 \
    auto local_status_ = status;    \
    if (!local_status_.IsOK()) {    \
      return local_status_;         \
    }                               \
  }
43 | |
// Base class for string -> int32 name tables. `Map` is the derived table type
// (see PrimitiveTypeNameMap / AttributeTypeNameMap): it must be
// default-constructible and is expected to populate `map_` in its constructor.
// All queries go through a single function-local static instance of `Map`.
template <typename Map>
class StringIntMap {
 public:
  // Returns the table held by the function-local static `Map` instance.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static Map instance;
    return instance.map_;
  }

  // Returns the value registered under `dtype`, or 0 when the name is absent.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return (found == table.end()) ? 0 : found->second;
  }

  // Reverse lookup: scans the table for an entry whose value equals `dtype`
  // and returns its name; returns the string "undefined" when none matches.
  static const std::string& ToString(int32_t dtype) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype) {
        return entry->first;
      }
    }
    return undefined;
  }

 protected:
  // Name -> value entries; filled in by the derived class.
  std::unordered_map<std::string, int32_t> map_;
};
71 | |
72 | class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> { |
73 | public: |
74 | PrimitiveTypeNameMap() : StringIntMap() { |
75 | map_["float" ] = 1; |
76 | map_["uint8" ] = 2; |
77 | map_["int8" ] = 3; |
78 | map_["uint16" ] = 4; |
79 | map_["int16" ] = 5; |
80 | map_["int32" ] = 6; |
81 | map_["int64" ] = 7; |
82 | map_["string" ] = 8; |
83 | map_["bool" ] = 9; |
84 | map_["float16" ] = 10; |
85 | map_["double" ] = 11; |
86 | map_["uint32" ] = 12; |
87 | map_["uint64" ] = 13; |
88 | map_["complex64" ] = 14; |
89 | map_["complex128" ] = 15; |
90 | map_["bfloat16" ] = 16; |
91 | } |
92 | |
93 | static bool IsTypeName(const std::string& dtype) { |
94 | return Lookup(dtype) != 0; |
95 | } |
96 | }; |
97 | |
98 | class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> { |
99 | public: |
100 | AttributeTypeNameMap() : StringIntMap() { |
101 | map_["float" ] = 1; |
102 | map_["int" ] = 2; |
103 | map_["string" ] = 3; |
104 | map_["tensor" ] = 4; |
105 | map_["graph" ] = 5; |
106 | map_["sparse_tensor" ] = 11; |
107 | map_["type_proto" ] = 13; |
108 | map_["floats" ] = 6; |
109 | map_["ints" ] = 7; |
110 | map_["strings" ] = 8; |
111 | map_["tensors" ] = 9; |
112 | map_["graphs" ] = 10; |
113 | map_["sparse_tensors" ] = 12; |
114 | map_["type_protos" ] = 14; |
115 | } |
116 | }; |
117 | |
// Maps the reserved words of the textual syntax to KeyWord enumerators and
// back. All queries go through one function-local static instance.
class KeyWordMap {
 public:
  // Reserved words. NONE is what Lookup() reports for any other identifier.
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS,
    SEQ_TYPE,
    MAP_TYPE,
    OPTIONAL_TYPE,
    SPARSE_TENSOR_TYPE
  };

  // Registers the spelling of every keyword.
  KeyWordMap() {
    // Model-level property keywords.
    map_.emplace("ir_version", KeyWord::IR_VERSION);
    map_.emplace("opset_import", KeyWord::OPSET_IMPORT);
    map_.emplace("producer_name", KeyWord::PRODUCER_NAME);
    map_.emplace("producer_version", KeyWord::PRODUCER_VERSION);
    map_.emplace("domain", KeyWord::DOMAIN_KW);
    map_.emplace("model_version", KeyWord::MODEL_VERSION);
    map_.emplace("doc_string", KeyWord::DOC_STRING);
    map_.emplace("metadata_props", KeyWord::METADATA_PROPS);
    // Type-constructor keywords.
    map_.emplace("seq", KeyWord::SEQ_TYPE);
    map_.emplace("map", KeyWord::MAP_TYPE);
    map_.emplace("optional", KeyWord::OPTIONAL_TYPE);
    map_.emplace("sparse_tensor", KeyWord::SPARSE_TENSOR_TYPE);
  }

  // Returns the table held by the function-local static KeyWordMap.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static KeyWordMap instance;
    return instance.map_;
  }

  // Returns the keyword spelled `id`, or KeyWord::NONE if `id` is not a keyword.
  static KeyWord Lookup(const std::string& id) {
    const auto& table = Instance();
    const auto found = table.find(id);
    return (found == table.end()) ? KeyWord::NONE : found->second;
  }

  // Returns the spelling registered for `kw` (found by scanning the table),
  // or the string "undefined" when no entry carries that value.
  static const std::string& ToString(KeyWord kw) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == kw) {
        return entry->first;
      }
    }
    return undefined;
  }

 private:
  // Keyword spelling -> enumerator.
  std::unordered_map<std::string, KeyWord> map_;
};
175 | |
176 | class ParserBase { |
177 | public: |
178 | ParserBase(const std::string& str) |
179 | : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {} |
180 | |
181 | ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {} |
182 | |
183 | void SavePos() { |
184 | saved_pos_ = next_; |
185 | } |
186 | |
187 | void RestorePos() { |
188 | next_ = saved_pos_; |
189 | } |
190 | |
191 | std::string GetCurrentPos() { |
192 | uint32_t line = 1, col = 1; |
193 | for (const char* p = start_; p < next_; ++p) { |
194 | if (*p == '\n') { |
195 | ++line; |
196 | col = 1; |
197 | } else { |
198 | ++col; |
199 | } |
200 | } |
201 | return ONNX_NAMESPACE::MakeString("(line: " , line, " column: " , col, ")" ); |
202 | } |
203 | |
204 | // Return a suitable suffix of what has been parsed to provide error message context: |
205 | // return the line containing the last non-space character preceding the error (if it exists). |
206 | std::string GetErrorContext() { |
207 | // Special cases: empty input string, and parse-error at first character. |
208 | const char* p = next_ < end_ ? next_ : next_ - 1; |
209 | while ((p > start_) && isspace(*p)) |
210 | --p; |
211 | while ((p > start_) && (*p != '\n')) |
212 | --p; |
213 | // Start at character after '\n' unless we are at start of input |
214 | const char* context_start = (p > start_) ? (p + 1) : start_; |
215 | for (p = context_start; (p < end_) && (*p != '\n'); ++p) |
216 | ; |
217 | return std::string(context_start, p - context_start); |
218 | } |
219 | |
220 | template <typename... Args> |
221 | Status ParseError(const Args&... args) { |
222 | return Status( |
223 | NONE, |
224 | FAIL, |
225 | ONNX_NAMESPACE::MakeString( |
226 | "[ParseError at position " , GetCurrentPos(), "]\n" , "Error context: " , GetErrorContext(), "\n" , args...)); |
227 | } |
228 | |
229 | void SkipWhiteSpace() { |
230 | do { |
231 | while ((next_ < end_) && (isspace(*next_))) |
232 | ++next_; |
233 | if ((next_ >= end_) || ((*next_) != '#')) |
234 | return; |
235 | // Skip rest of the line: |
236 | while ((next_ < end_) && ((*next_) != '\n')) |
237 | ++next_; |
238 | } while (true); |
239 | } |
240 | |
241 | int NextChar(bool skipspace = true) { |
242 | if (skipspace) |
243 | SkipWhiteSpace(); |
244 | return (next_ < end_) ? *next_ : 0; |
245 | } |
246 | |
247 | bool Matches(char ch, bool skipspace = true) { |
248 | if (skipspace) |
249 | SkipWhiteSpace(); |
250 | if ((next_ < end_) && (*next_ == ch)) { |
251 | ++next_; |
252 | return true; |
253 | } |
254 | return false; |
255 | } |
256 | |
257 | Status Match(char ch, bool skipspace = true) { |
258 | if (!Matches(ch, skipspace)) |
259 | return ParseError("Expected character " , ch, " not found." ); |
260 | return Status::OK(); |
261 | } |
262 | |
263 | bool EndOfInput() { |
264 | SkipWhiteSpace(); |
265 | return (next_ >= end_); |
266 | } |
267 | |
268 | enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL }; |
269 | |
270 | struct Literal { |
271 | LiteralType type; |
272 | std::string value; |
273 | }; |
274 | |
275 | Status Parse(Literal& result); |
276 | |
277 | Status Parse(int64_t& val) { |
278 | Literal literal; |
279 | CHECK_PARSER_STATUS(Parse(literal)); |
280 | if (literal.type != LiteralType::INT_LITERAL) |
281 | return ParseError("Integer value expected, but not found." ); |
282 | std::string s = literal.value; |
283 | val = std::stoll(s); |
284 | return Status::OK(); |
285 | } |
286 | |
287 | Status Parse(uint64_t& val) { |
288 | Literal literal; |
289 | CHECK_PARSER_STATUS(Parse(literal)); |
290 | if (literal.type != LiteralType::INT_LITERAL) |
291 | return ParseError("Integer value expected, but not found." ); |
292 | std::string s = literal.value; |
293 | val = std::stoull(s); |
294 | return Status::OK(); |
295 | } |
296 | |
297 | Status Parse(float& val) { |
298 | Literal literal; |
299 | CHECK_PARSER_STATUS(Parse(literal)); |
300 | switch (literal.type) { |
301 | case LiteralType::INT_LITERAL: |
302 | case LiteralType::FLOAT_LITERAL: |
303 | val = std::stof(literal.value); |
304 | break; |
305 | default: |
306 | return ParseError("Unexpected literal type." ); |
307 | } |
308 | return Status::OK(); |
309 | } |
310 | |
311 | Status Parse(double& val) { |
312 | Literal literal; |
313 | CHECK_PARSER_STATUS(Parse(literal)); |
314 | switch (literal.type) { |
315 | case LiteralType::INT_LITERAL: |
316 | case LiteralType::FLOAT_LITERAL: |
317 | val = std::stod(literal.value); |
318 | break; |
319 | default: |
320 | return ParseError("Unexpected literal type." ); |
321 | } |
322 | return Status::OK(); |
323 | } |
324 | |
325 | // Parse a string-literal enclosed within doube-quotes. |
326 | Status Parse(std::string& val) { |
327 | Literal literal; |
328 | CHECK_PARSER_STATUS(Parse(literal)); |
329 | if (literal.type != LiteralType::STRING_LITERAL) |
330 | return ParseError("String value expected, but not found." ); |
331 | val = literal.value; |
332 | return Status::OK(); |
333 | } |
334 | |
335 | // Parse an identifier, including keywords. If none found, this will |
336 | // return an empty-string identifier. |
337 | Status ParseOptionalIdentifier(std::string& id) { |
338 | SkipWhiteSpace(); |
339 | auto from = next_; |
340 | if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) { |
341 | ++next_; |
342 | while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_'))) |
343 | ++next_; |
344 | } |
345 | id = std::string(from, next_ - from); |
346 | return Status::OK(); |
347 | } |
348 | |
349 | Status ParseIdentifier(std::string& id) { |
350 | ParseOptionalIdentifier(id); |
351 | if (id.empty()) |
352 | return ParseError("Identifier expected but not found." ); |
353 | return Status::OK(); |
354 | } |
355 | |
356 | Status PeekIdentifier(std::string& id) { |
357 | SavePos(); |
358 | ParseOptionalIdentifier(id); |
359 | RestorePos(); |
360 | return Status::OK(); |
361 | } |
362 | |
363 | Status Parse(KeyWordMap::KeyWord& keyword) { |
364 | std::string id; |
365 | CHECK_PARSER_STATUS(ParseIdentifier(id)); |
366 | keyword = KeyWordMap::Lookup(id); |
367 | return Status::OK(); |
368 | } |
369 | |
370 | protected: |
371 | const char* start_; |
372 | const char* next_; |
373 | const char* end_; |
374 | const char* saved_pos_; |
375 | }; |
376 | |
377 | class OnnxParser : public ParserBase { |
378 | public: |
379 | OnnxParser(const char* cstr) : ParserBase(cstr) {} |
380 | |
381 | Status Parse(TensorShapeProto& shape); |
382 | |
383 | Status Parse(TypeProto& typeProto); |
384 | |
385 | Status Parse(TensorProto& tensorProto); |
386 | |
387 | Status Parse(AttributeProto& attr); |
388 | |
389 | Status Parse(AttrList& attrlist); |
390 | |
391 | Status Parse(NodeProto& node); |
392 | |
393 | Status Parse(NodeList& nodelist); |
394 | |
395 | Status Parse(GraphProto& graph); |
396 | |
397 | Status Parse(FunctionProto& fn); |
398 | |
399 | Status Parse(ModelProto& model); |
400 | |
401 | template <typename T> |
402 | static Status Parse(T& parsedData, const char* input) { |
403 | OnnxParser parser(input); |
404 | return parser.Parse(parsedData); |
405 | } |
406 | |
407 | private: |
408 | Status Parse(std::string name, GraphProto& graph); |
409 | |
410 | Status Parse(IdList& idlist); |
411 | |
412 | Status Parse(char open, IdList& idlist, char close); |
413 | |
414 | Status ParseSingleAttributeValue(AttributeProto& attr); |
415 | |
416 | Status Parse(ValueInfoProto& valueinfo); |
417 | |
418 | Status Parse(ValueInfoList& vilist); |
419 | |
420 | Status ParseInput(ValueInfoList& vilist, TensorList& initializers); |
421 | |
422 | Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers); |
423 | |
424 | Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto); |
425 | |
426 | Status Parse(OpsetIdList& opsets); |
427 | |
428 | bool NextIsType(); |
429 | }; |
430 | |
431 | } // namespace ONNX_NAMESPACE |