1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "onnx/checker.h" |
6 | #include "onnx/common/file_utils.h" |
7 | #include "onnx/common/path.h" |
8 | #include "onnx/defs/schema.h" |
9 | #include "onnx/defs/tensor_proto_util.h" |
10 | #include "onnx/proto_utils.h" |
11 | #include "onnx/shape_inference/implementation.h" |
12 | #include "onnx/string_utils.h" |
13 | |
14 | #include <fstream> |
15 | #include <iterator> |
16 | #include <unordered_set> |
17 | |
18 | #ifdef _WIN32 |
19 | #include <direct.h> |
20 | #include <filesystem> |
21 | |
22 | #else // POSIX |
23 | #include <sys/stat.h> |
24 | #endif |
25 | |
26 | namespace ONNX_NAMESPACE { |
27 | namespace checker { |
28 | |
// Fails the check when the given singular protobuf field is not set on
// `proto`. Note: #proto stringizes the *variable name* into the error
// message, so call sites must use meaningful variable names.
#define enforce_has_field(proto, field) \
  do { \
    if (!proto.has_##field()) { \
      fail_check("Field '", #field, "' of '", #proto, "' is required but missing."); \
    } \
  } while (0)

// Fails the check when the given repeated protobuf field has zero entries.
#define enforce_has_repeated_field(proto, field) \
  do { \
    if (!proto.field##_size()) { \
      fail_check("Repeated Field '", #field, "' of '", #proto, "' is required but missing."); \
    } \
  } while (0)

// Fails the check when the given field (string or repeated) is empty.
#define enforce_non_empty_field(proto, field) \
  do { \
    if (proto.field().empty()) { \
      fail_check("Field '", #field, "' of '", #proto, "' is required to be non-empty."); \
    } \
  } while (0)
49 | |
50 | void check_value_info(const ValueInfoProto& value_info, const CheckerContext& ctx) { |
51 | enforce_non_empty_field(value_info, name); |
52 | // Relax constraint for subgraph input/output. |
53 | if (!ctx.is_main_graph()) |
54 | return; |
55 | enforce_has_field(value_info, type); |
56 | const auto value_case = value_info.type().value_case(); |
57 | switch (value_case) { |
58 | case TypeProto::kTensorType: { |
59 | const auto& type = value_info.type().tensor_type(); |
60 | enforce_has_field(type, elem_type); |
61 | enforce_has_field(type, shape); |
62 | } break; |
63 | case TypeProto::kOptionalType: { |
64 | const auto& type = value_info.type().optional_type(); |
65 | enforce_has_field(type, elem_type); |
66 | } break; |
67 | case TypeProto::kSequenceType: { |
68 | const auto& type = value_info.type().sequence_type(); |
69 | enforce_has_field(type, elem_type); |
70 | } break; |
71 | case TypeProto::kMapType: { |
72 | const auto& type = value_info.type().map_type(); |
73 | enforce_has_field(type, key_type); |
74 | enforce_has_field(type, value_type); |
75 | } break; |
76 | #ifdef ONNX_ML |
77 | case TypeProto::kOpaqueType: |
78 | break; |
79 | #endif |
80 | case TypeProto::kSparseTensorType: { |
81 | const auto& type = value_info.type().sparse_tensor_type(); |
82 | enforce_has_field(type, elem_type); |
83 | enforce_has_field(type, shape); |
84 | } break; |
85 | |
86 | default: |
87 | fail_check("Unrecognized type value case (value_info name: " , value_info.name(), "): " , value_case); |
88 | } |
89 | } |
90 | |
// Validates a TensorProto.
// - data_type must be present and must not be UNDEFINED.
// - At most one typed value field (float_data, int32_data, ..., raw_data)
//   may be populated; a non-empty tensor must populate exactly one, and it
//   must match data_type.
// - Externally stored tensors must populate no value field and must carry a
//   "location" external_data entry that is a relative path (no "..", not
//   absolute) to an existing file under the model directory.
void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) {
  enforce_has_field(tensor, data_type);
  if (tensor.data_type() == TensorProto::UNDEFINED) {
    fail_check("setting data_type field (tensor name: ", tensor.name(), ") to UNDEFINED is not allowed");
  }

  // How many of the mutually-exclusive value fields are populated.
  int num_value_fields = 0;

  // Name of the (last seen) populated value field, used in error messages.
  const char* value_field = nullptr;

// Each expansion declares a has_<field> local; has_raw_data and has_<typed>
// are referenced again after the #undef, so these declarations must stay in
// function scope.
#define check_data_field(field) \
  bool has_##field = tensor.field().size(); \
  if (has_##field) { \
    ++num_value_fields; \
    value_field = #field; \
  }

  check_data_field(float_data);
  check_data_field(int32_data);
  check_data_field(string_data);
  check_data_field(int64_data);
  check_data_field(raw_data);
  check_data_field(double_data);
  check_data_field(uint64_data);

#undef check_data_field

  bool stored_externally = tensor.has_data_location() && tensor.data_location() == TensorProto::EXTERNAL;
  if (stored_externally) {
    // External tensors keep their payload in a side file, so in-proto value
    // fields are forbidden.
    if (num_value_fields != 0) {
      fail_check(
          "Data of TensorProto ( tensor name: ",
          tensor.name(),
          ") is stored externally and should not have data field.",
          value_field);
    }

    bool has_location = false;
    for (const StringStringEntryProto& entry : tensor.external_data()) {
      if (entry.has_key() && entry.has_value() && entry.key() == "location") {
        has_location = true;
#ifdef _WIN32
        auto file_path = std::filesystem::path(utf8str_to_wstring(entry.value()));
        if (file_path.is_absolute()) {
          fail_check(
              "Location of external TensorProto ( tensor name: ",
              tensor.name(),
              ") should be a relative path, but it is an absolute path: ",
              entry.value());
        }
        auto relative_path = file_path.lexically_normal().make_preferred().wstring();
        // Check that normalized relative path contains ".." on Windows.
        // (std::string::npos and std::wstring::npos are the same value, so
        // this comparison is well-defined.)
        if (relative_path.find(L"..", 0) != std::string::npos) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be file inside the ",
              ctx.get_model_dir(),
              ", but the '",
              entry.value(),
              "' points outside the directory");
        }
        std::wstring data_path = path_join(utf8str_to_wstring(ctx.get_model_dir()), relative_path);
        struct _stat buff;
        // _wstat returning non-zero means the file is missing or inaccessible.
        // NOTE(review): this message prints entry.value() while the POSIX
        // branch prints the joined data_path — intentional? confirm.
        if (_wstat(data_path.c_str(), &buff) != 0) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              entry.value(),
              ", but it doesn't exist or is not accessible.");
        }
#else // POSIX
        if (entry.value().empty()) {
          fail_check("Location of external TensorProto ( tensor name: ", tensor.name(), ") should not be empty.");
        } else if (entry.value()[0] == '/') {
          fail_check(
              "Location of external TensorProto ( tensor name: ",
              tensor.name(),
              ") should be a relative path, but it is an absolute path: ",
              entry.value());
        }
        std::string relative_path = clean_relative_path(entry.value());
        // Check that normalized relative path contains ".." on POSIX
        if (relative_path.find("..", 0) != std::string::npos) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be file inside the ",
              ctx.get_model_dir(),
              ", but the '",
              entry.value(),
              "' points outside the directory");
        }
        std::string data_path = path_join(ctx.get_model_dir(), relative_path);
        // use stat to check whether the file exists
        struct stat buffer;
        if (stat((data_path).c_str(), &buffer) != 0) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              data_path,
              ", but it doesn't exist or is not accessible.");
        }
        // Do not allow symlinks or directories.
        if (!S_ISREG(buffer.st_mode)) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              data_path,
              ", but it is not regular file.");
        }
#endif
      }
    }
    if (!has_location) {
      fail_check("TensorProto ( tensor name: ", tensor.name(), ") is stored externally but doesn't have a location.");
    }
    // External tensors carry no inline data; nothing more to validate.
    return;
  }
  // Total element count implied by dims; 0 means the tensor is empty.
  int64_t nelem = 1;
  for (auto x : tensor.dims()) {
    nelem *= x;
  }
  if (nelem == 0 && num_value_fields != 0) {
    fail_check("TensorProto (tensor name: ", tensor.name(), ") is 0-element but contains data!");
  }
  if (nelem != 0 && num_value_fields != 1) {
    fail_check("TensorProto (tensor name: ", tensor.name(), ") should contain one and only one value field.");
  }
  if (has_raw_data) {
    // raw_data may hold any type except STRING (strings have no fixed-width
    // raw encoding).
    if (tensor.data_type() == TensorProto::STRING) {
      fail_check("STRING data (tensor name: ", tensor.name(), ") should not be stored in raw_data field");
    }
    return;
  } else {
// Verifies the populated value field matches the declared data_type; the
// has_<field> locals come from check_data_field expansions above.
#define check_field(field) \
  if (nelem != 0 && !has_##field) { \
    fail_check( \
        "values of data_type '", \
        tensor.data_type(), \
        "' should be stored in field '", \
        #field, \
        "' instead of '", \
        value_field, \
        "'"); \
  }

    switch (tensor.data_type()) {
      case TensorProto::FLOAT:
      case TensorProto::COMPLEX64:
        check_field(float_data);
        break;

      case TensorProto::DOUBLE:
      case TensorProto::COMPLEX128:
        check_field(double_data);
        break;

      // Small integer and half-float types are all packed into int32_data.
      case TensorProto::INT32:
      case TensorProto::UINT8:
      case TensorProto::INT8:
      case TensorProto::UINT16:
      case TensorProto::INT16:
      case TensorProto::BOOL:
      case TensorProto::FLOAT16:
      case TensorProto::BFLOAT16:
        check_field(int32_data);
        break;

      case TensorProto::INT64:
        check_field(int64_data);
        break;

      case TensorProto::UINT32:
      case TensorProto::UINT64:
        check_field(uint64_data);
        break;

      case TensorProto::STRING:
        check_field(string_data);
        break;

      default:
        fail_check("Unrecognized data_type (tensor name: ", tensor.name(), "): ", tensor.data_type());
    }
  }

#undef check_field
}
283 | |
284 | void check_sequence(const SequenceProto& sequence, const CheckerContext& ctx) { |
285 | enforce_has_field(sequence, elem_type); |
286 | if (sequence.elem_type() == SequenceProto::TENSOR) { |
287 | for (const TensorProto& tensor : sequence.tensor_values()) { |
288 | check_tensor(tensor, ctx); |
289 | } |
290 | } else if (sequence.elem_type() == SequenceProto::SPARSE_TENSOR) { |
291 | for (const SparseTensorProto& sparse_tensor : sequence.sparse_tensor_values()) { |
292 | check_sparse_tensor(sparse_tensor, ctx); |
293 | } |
294 | } else if (sequence.elem_type() == SequenceProto::SEQUENCE) { |
295 | for (const SequenceProto& seq : sequence.sequence_values()) { |
296 | check_sequence(seq, ctx); |
297 | } |
298 | } else if (sequence.elem_type() == SequenceProto::MAP) { |
299 | for (const MapProto& map : sequence.map_values()) { |
300 | check_map(map, ctx); |
301 | } |
302 | } else { |
303 | fail_check( |
304 | "Sequence ( Structure name: " , |
305 | sequence.name(), |
306 | ", elem_type: " , |
307 | sequence.elem_type(), |
308 | ") is not have a valid element type." ); |
309 | } |
310 | } |
311 | |
312 | void check_optional(const OptionalProto& optional, const CheckerContext& ctx) { |
313 | enforce_has_field(optional, elem_type); |
314 | if (optional.elem_type() == OptionalProto::UNDEFINED) { |
315 | return; |
316 | } else if (optional.elem_type() == OptionalProto::TENSOR) { |
317 | if (optional.has_tensor_value()) |
318 | check_tensor(optional.tensor_value(), ctx); |
319 | } else if (optional.elem_type() == OptionalProto::SPARSE_TENSOR) { |
320 | if (optional.has_sparse_tensor_value()) |
321 | check_sparse_tensor(optional.sparse_tensor_value(), ctx); |
322 | } else if (optional.elem_type() == OptionalProto::SEQUENCE) { |
323 | if (optional.has_sequence_value()) |
324 | check_sequence(optional.sequence_value(), ctx); |
325 | } else if (optional.elem_type() == OptionalProto::MAP) { |
326 | if (optional.has_map_value()) |
327 | check_map(optional.map_value(), ctx); |
328 | } else { |
329 | fail_check( |
330 | "Optional ( Structure name: " , |
331 | optional.name(), |
332 | ", elem_type: " , |
333 | optional.elem_type(), |
334 | ") is not have a valid element type." ); |
335 | } |
336 | } |
337 | |
338 | void check_map(const MapProto& map, const CheckerContext& ctx) { |
339 | enforce_has_field(map, key_type); |
340 | if (map.key_type() == TensorProto::UNDEFINED) { |
341 | fail_check("setting key_type field (map name: " , map.name(), ") to UNDEFINED is not allowed" ); |
342 | } |
343 | // Check if key is a valid type, specifically INT8, INT16, INT32, INT64, |
344 | // UINT8, UINT16, UINT32, UINT64, or STRING. |
345 | if ((map.key_type() == TensorProto::FLOAT) || (map.key_type() == TensorProto::BOOL) || |
346 | (map.key_type() == TensorProto::FLOAT16) || (map.key_type() == TensorProto::COMPLEX64) || |
347 | (map.key_type() == TensorProto::COMPLEX128)) { |
348 | fail_check( |
349 | "setting key_type field (map name: " , |
350 | map.name(), |
351 | ") to invalid TensorProto key_type " , |
352 | map.key_type(), |
353 | " is not allowed" ); |
354 | } |
355 | |
356 | // MapProto will use either keys or string_keys, so only one should be > 0. |
357 | if ((map.keys_size() > 0) && (map.string_keys_size() > 0)) { |
358 | fail_check("Map (name: " , map.name(), ") should not contain more than one keys field." ); |
359 | } |
360 | |
361 | int num_keys = map.keys_size() + map.string_keys_size(); |
362 | int num_values = 0; |
363 | |
364 | enforce_has_field(map, values); |
365 | check_sequence(map.values(), ctx); |
366 | |
367 | if (map.values().elem_type() == SequenceProto::TENSOR) { |
368 | num_values = map.values().tensor_values_size(); |
369 | } else if (map.values().elem_type() == SequenceProto::SPARSE_TENSOR) { |
370 | num_values = map.values().sparse_tensor_values_size(); |
371 | } else if (map.values().elem_type() == SequenceProto::SEQUENCE) { |
372 | num_values = map.values().sequence_values_size(); |
373 | } else if (map.values().elem_type() == SequenceProto::MAP) { |
374 | num_values = map.values().map_values_size(); |
375 | } |
376 | |
377 | if (num_keys != num_values) { |
378 | fail_check("Length of map keys and map values are not the same (map name: " , map.name(), ")" ); |
379 | } |
380 | } |
381 | |
382 | // Check that the index data stored in a SparseTensorProto is valid. |
383 | // indices: a 1-dimensional tensor; indices[i] represents the |
384 | // linearized index value for the i-th nonzero value. |
385 | void check_sparse_tensor_indices_1( |
386 | const TensorProto& indices, |
387 | const SparseTensorProto& sparse_tensor_proto, |
388 | size_t nnz) { |
389 | int dense_rank = sparse_tensor_proto.dims_size(); |
390 | int64_t dense_size = 1; |
391 | for (int i = 0; i < dense_rank; ++i) |
392 | dense_size *= sparse_tensor_proto.dims(i); |
393 | if (static_cast<size_t>(indices.dims(0)) != nnz) { |
394 | fail_check("Sparse tensor indices (" , indices.name(), ") has " , indices.dims(0), " values, but NNZ is " , nnz); |
395 | } |
396 | |
397 | // Check if indices appear in ascending order, and if they have valid |
398 | // values. The i-th value in index_data is the linear index of the i-th |
399 | // non-zero value. |
400 | const std::vector<int64_t> index_data = ParseData<int64_t>(&indices); |
401 | |
402 | int64_t prev_index = -1; |
403 | for (size_t i = 0; i < nnz; ++i) { |
404 | int64_t curr_index = index_data[i]; // linearized index of i-th value |
405 | if (curr_index < 0 || curr_index >= dense_size) { |
406 | fail_check( |
407 | "Sparse tensor (" , |
408 | indices.name(), |
409 | ") index value at position [" , |
410 | i, |
411 | "] out of range [0, " , |
412 | dense_size - 1, |
413 | "]" ); |
414 | } |
415 | if (curr_index <= prev_index) { |
416 | fail_check("Sparse tensor (" , indices.name(), ") index value at position [" , i, "] not in sorted order." ); |
417 | } |
418 | prev_index = curr_index; |
419 | } |
420 | } |
421 | |
422 | // Check that the index data stored in a SparseTensorProto is valid. |
423 | // indices: a 2-dimensional tensor; indices[i,j] represents the j-th |
424 | // index value for the i-th nonzero value. |
425 | void check_sparse_tensor_indices_2( |
426 | const TensorProto& indices, |
427 | const SparseTensorProto& sparse_tensor_proto, |
428 | size_t nnz) { |
429 | int dense_rank = sparse_tensor_proto.dims_size(); |
430 | if (static_cast<size_t>(indices.dims(0)) != nnz) { |
431 | fail_check("Sparse tensor indices (" , indices.name(), ") first dimension size does not equal NNZ." ); |
432 | } |
433 | if (indices.dims(1) != dense_rank) { |
434 | fail_check("Sparse tensor indices (" , indices.name(), ") second dimension size does not match rank of tensor." ); |
435 | } |
436 | |
437 | // Check if indices appear in ascending order, and if they have valid |
438 | // values. |
439 | const std::vector<int64_t> index_data = ParseData<int64_t>(&indices); |
440 | int64_t prev_index = -1; |
441 | for (size_t i = 0; i < nnz; ++i) { |
442 | int64_t curr_index = 0; // linearized index of i-th value |
443 | for (int j = 0; j < dense_rank; ++j) { |
444 | auto index_ij = index_data[i * dense_rank + j]; |
445 | if ((index_ij < 0) || (index_ij >= sparse_tensor_proto.dims(j))) { |
446 | fail_check("Sparse tensor (" , indices.name(), ") index value at position [" , i, "," , j, "] out of range." ); |
447 | } |
448 | curr_index = curr_index * sparse_tensor_proto.dims(j) + index_ij; |
449 | } |
450 | if (curr_index <= prev_index) { |
451 | fail_check( |
452 | "Sparse tensor (" , indices.name(), ") index value at position [" , i, "] not in lexicographic sorted order." ); |
453 | } |
454 | prev_index = curr_index; |
455 | } |
456 | } |
457 | |
458 | void check_sparse_tensor(const SparseTensorProto& sparse_tensor_proto, const CheckerContext& ctx) { |
459 | enforce_has_field(sparse_tensor_proto, values); |
460 | |
461 | const TensorProto& values = sparse_tensor_proto.values(); |
462 | check_tensor(values, ctx); |
463 | |
464 | // values must be a tensor of shape [NNZ] |
465 | // Currently we restrict the value associated with a particular index-tuple |
466 | // to be a single value. In the future, if there is a requirement, |
467 | // we may extend this to permit the value to be a "sub-tensor", in which |
468 | // case values will have dimension > 1. |
469 | if (values.dims_size() != 1) { |
470 | fail_check("Sparse tensor values (" , values.name(), ") must have rank 1." ); |
471 | } |
472 | size_t nnz = static_cast<size_t>(values.dims(0)); |
473 | int dense_rank = sparse_tensor_proto.dims_size(); |
474 | if (dense_rank == 0) { |
475 | fail_check("Sparse tensor (" , values.name(), ") must have a dense-rank > 0" ); |
476 | } |
477 | for (int i = 0; i < dense_rank; ++i) { |
478 | if (sparse_tensor_proto.dims(i) <= 0) { |
479 | fail_check("Sparse tensor (" , values.name(), ") dimensions are not positive." ); |
480 | } |
481 | } |
482 | |
483 | if (sparse_tensor_proto.has_indices()) { |
484 | const TensorProto& indices = sparse_tensor_proto.indices(); |
485 | check_tensor(indices, ctx); |
486 | if (indices.data_type() != TensorProto::INT64) { |
487 | fail_check("Sparse tensor indices (" , indices.name(), ") must have INT64 type." ); |
488 | } |
489 | switch (indices.dims().size()) { |
490 | case 1: |
491 | // Indices in linearized format |
492 | check_sparse_tensor_indices_1(indices, sparse_tensor_proto, nnz); |
493 | return; |
494 | case 2: |
495 | // Check COO-style index. E.g., an index for a 3D tensor is a 3-tuple. |
496 | check_sparse_tensor_indices_2(indices, sparse_tensor_proto, nnz); |
497 | return; |
498 | default: |
499 | fail_check("Sparse tensor indices (" , indices.name(), ") must have rank 1 or 2." ); |
500 | } |
501 | } else if (nnz != 0) { |
502 | fail_check("Sparse tensor (" , values.name(), ") has no index values." ); |
503 | } |
504 | } |
505 | |
// NB: This is a generic "attribute well-formedness" check, it doesn't
// actually test if an attribute is valid per a schema
void check_attribute(const AttributeProto& attr, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx) {
  enforce_non_empty_field(attr, name);

  // The 'type' field on attributes became mandatory in IR version 2.
  if (ctx.get_ir_version() >= 0x00000002) {
    enforce_has_field(attr, type);
  }

  // Number of distinct value fields populated on this attribute.
  int used_fields = 0;

// When 'type' is set, it must agree with whichever value field is populated.
#define check_type(expected_type) \
  if (attr.has_type() && attr.type() != expected_type) { \
    fail_check("type field and data field mismatch in attribute ", attr.name(), "."); \
  }

// Counts a populated singular value field and validates its declared type.
#define check_singular_field(field, type) \
  if (attr.has_##field()) { \
    ++used_fields; \
    check_type(type); \
  }

// Counts a populated repeated value field and validates its declared type.
#define check_repeated_field(field, type) \
  if (attr.field##_size() > 0) { \
    ++used_fields; \
    check_type(type); \
  }

  check_singular_field(f, AttributeProto::FLOAT);
  check_singular_field(i, AttributeProto::INT);
  check_singular_field(s, AttributeProto::STRING);
  check_singular_field(t, AttributeProto::TENSOR);
  check_singular_field(g, AttributeProto::GRAPH);
  check_singular_field(tp, AttributeProto::TYPE_PROTO);
  check_singular_field(sparse_tensor, AttributeProto::SPARSE_TENSOR);
  check_repeated_field(floats, AttributeProto::FLOATS);
  check_repeated_field(ints, AttributeProto::INTS);
  check_repeated_field(strings, AttributeProto::STRINGS);
  check_repeated_field(tensors, AttributeProto::TENSORS);
  check_repeated_field(graphs, AttributeProto::GRAPHS);
  check_repeated_field(sparse_tensors, AttributeProto::SPARSE_TENSORS);
  check_repeated_field(type_protos, AttributeProto::TYPE_PROTOS);

#undef check_type
#undef check_singular_field
#undef check_repeated_field

  // Normally, used_fields is expected to be 1.
  // In proto3, when the value to be set is type default value (say 0 for
  // int), used_fields may be 0.
  if (used_fields > 1) {
    fail_check("Attribute (name: ", attr.name(), ") should not contain more than one value field.");
  }

  if (!ctx.is_main_graph()) {
    // It's an attribute of a node in function body.
    if (attr.has_ref_attr_name() && used_fields != 0) {
      // The attribute proto is supposed to refer to data outside and does not
      // have its own value field set.
      fail_check("Attribute (name: ", attr.name(), ") should refer to attribute in parent node.");
    }
  }

  // Recursively validate any tensor/sparse-tensor/graph payloads.
  if (attr.has_t()) {
    check_tensor(attr.t(), ctx);
  }

  if (attr.has_sparse_tensor()) {
    check_sparse_tensor(attr.sparse_tensor(), ctx);
  }

  if (attr.has_g()) {
    // Subgraphs are checked under a context flagged as non-main so that
    // relaxed subgraph rules apply.
    CheckerContext subgraph_ctx(ctx);
    subgraph_ctx.set_is_main_graph(false);
    check_graph(attr.g(), subgraph_ctx, lex_ctx);
  }

  for (const auto& tensor : attr.tensors()) {
    check_tensor(tensor, ctx);
  }
  for (const auto& sparse_tensor : attr.sparse_tensors()) {
    check_sparse_tensor(sparse_tensor, ctx);
  }
  if (attr.graphs().size() > 0) {
    CheckerContext subgraph_ctx(ctx);
    subgraph_ctx.set_is_main_graph(false);
    for (const auto& graph : attr.graphs()) {
      check_graph(graph, subgraph_ctx, lex_ctx);
    }
  }
}
597 | |
// Emits a single stdout warning listing every experimental op in the set;
// prints nothing when the set is empty.
void print_warning_if_has_experimental(const std::unordered_set<std::string>& used_experimental_ops) {
  if (used_experimental_ops.empty()) {
    return;
  }
  std::string joined;
  for (const auto& op_type : used_experimental_ops) {
    joined += " " + op_type + ",";
  }
  // Drop the trailing comma (the string is guaranteed non-empty here).
  joined.pop_back();
  std::cout << "Warning: Model contains experimental ops:" + joined << std::endl;
}
609 | |
610 | void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx) { |
611 | enforce_non_empty_field(node, op_type); |
612 | |
613 | if (node.input().empty() && node.output().empty()) { |
614 | fail_check("NodeProto (name: " , node.name(), ", type: " , node.op_type(), ") has zero input and zero output." ); |
615 | } |
616 | |
617 | // Resolve domain for node |
618 | const auto& opset_imports = ctx.get_opset_imports(); |
619 | auto dit = opset_imports.find(node.domain()); |
620 | if (dit == opset_imports.end()) { |
621 | fail_check("No opset import for domain '" + node.domain() + "'" ); |
622 | } |
623 | auto domain_version = dit->second; |
624 | |
625 | for (const auto& attr : node.attribute()) { |
626 | check_attribute(attr, ctx, lex_ctx); |
627 | } |
628 | |
629 | // This issue will be caught by check_graph instead |
630 | if (check_is_experimental_op(node)) { |
631 | return; |
632 | } |
633 | |
634 | const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain()); |
635 | if (!schema) { |
636 | if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" || |
637 | node.domain() == AI_ONNX_TRAINING_DOMAIN) { |
638 | // fail the checker if op in built-in domains has no schema |
639 | fail_check( |
640 | "No Op registered for " + node.op_type() + " with domain_version of " + |
641 | ONNX_NAMESPACE::to_string(domain_version)); |
642 | } else { |
643 | // TODO: expose the registration of the op schemas appropriately in |
644 | // python, so we can load and register operators in other domains |
645 | // |
646 | // before we complete the above todo, let's skip the schema check for |
647 | // now |
648 | } |
649 | } else if (schema->Deprecated()) { |
650 | fail_check( |
651 | "Op registered for " + node.op_type() + " is deprecated in domain_version of " + |
652 | ONNX_NAMESPACE::to_string(domain_version)); |
653 | } else { |
654 | schema->Verify(node); |
655 | } |
656 | } |
657 | |
// Validates a GraphProto: inputs/outputs, dense and sparse initializers,
// and nodes. Enforces topological ordering of nodes and single static
// assignment (SSA) of value names. Names from ancestor graphs are inherited
// via parent_lex; shadowing is not allowed.
void check_graph(const GraphProto& graph, const CheckerContext& ctx, const LexicalScopeContext& parent_lex) {
  enforce_non_empty_field(graph, name);

  for (const auto& value_info : graph.input()) {
    check_value_info(value_info, ctx);
  }
  for (const auto& value_info : graph.output()) {
    check_value_info(value_info, ctx);
  }

  // Inherit values available in outer scope
  // Note that we do not allow shadowing, so the presence of an already-defined
  // name is always an error.
  LexicalScopeContext lex_ctx{parent_lex};

  for (const auto& value_info : graph.input()) {
    // TODO: If shadowing isn't allowed, this should maybe use
    // this_or_ancestor_graph_has
    if (lex_ctx.this_graph_has(value_info.name())) {
      fail_check(
          "Graph must be in single static assignment (SSA) form, however '",
          value_info.name(),
          "' has been used as graph input names multiple times.");
    }
    lex_ctx.add(value_info.name());
  }

  // Uniqueness of initializer names is enforced across dense and sparse
  // initializers together; reference_wrapper avoids copying the name strings.
  std::unordered_set<std::reference_wrapper<const std::string>, std::hash<std::string>, std::equal_to<std::string>>
      initializer_name_checker;

  for (const auto& init : graph.initializer()) {
    enforce_has_field(init, name);
    const auto& name = init.name();
    if (name.empty()) {
      fail_check("Tensor initializers must have a non-empty name");
    }

    if (!initializer_name_checker.insert(std::cref(name)).second) {
      fail_check(name + " initializer name is not unique");
    }

    check_tensor(init, ctx);

    if (ctx.get_ir_version() <= 0x00000003) {
      // Initializers are a subset of graph inputs for IR_VERSION <= 3
      if (!lex_ctx.this_graph_has(name)) {
        fail_check(name + " in initializer but not in graph input");
      }
    } else {
      // An initializer is allowed to have the same name as an input,
      // but is not required to (for IR_VERSION >= 4)
      lex_ctx.add(name);
    }
  }

  for (const auto& sparse_init : graph.sparse_initializer()) {
    // A sparse initializer's name lives on its values tensor.
    const auto& values = sparse_init.values();
    enforce_has_field(values, name);
    const auto& name = values.name();
    if (name.empty()) {
      fail_check("Sparse tensor initializers must have a non-empty name");
    }
    if (!initializer_name_checker.insert(std::cref(name)).second) {
      fail_check(name + " sparse initializer name is not unique across initializers and sparse_initializers");
    }
    check_sparse_tensor(sparse_init, ctx);
    lex_ctx.add(name);
  }
  std::unordered_set<std::string> used_experimental_ops;
  for (const auto& node : graph.node()) {
    // nodes must be in topologically sorted order
    for (const auto& input : node.input()) {
      // explicit optional input
      if (input.empty()) {
        continue;
      }
      // Every non-empty input must already be defined (graph input,
      // initializer, or output of an earlier node in this or an outer graph).
      if (!lex_ctx.this_or_ancestor_graph_has(input)) {
        fail_check(
            "Nodes in a graph must be topologically sorted, however input '",
            input,
            "' of node: \n",
            "name: ",
            node.name(),
            " OpType: ",
            node.op_type(),
            "\n is not output of any previous nodes.");
      }
    }

    if (check_is_experimental_op(node)) {
      used_experimental_ops.insert(node.op_type());
    }

    // This needs to happen before SSA check since we don't want to recurse and
    // find that outputs from control flow ops are colliding with names in the
    // inner block

    ONNX_TRY {
      check_node(node, ctx, lex_ctx);
    }
    ONNX_CATCH(ValidationError & ex) {
      ONNX_HANDLE_EXCEPTION([&]() {
        // Attach node identification to the error before rethrowing.
        ex.AppendContext("Bad node spec for node. Name: " + node.name() + " OpType: " + node.op_type());
        ONNX_THROW_EX(ex);
      });
    }
    // check for SSA form
    for (const auto& output : node.output()) {
      // optional output
      if (output.empty()) {
        continue;
      }

      if (lex_ctx.this_or_ancestor_graph_has(output)) {
        fail_check(
            "Graph must be in single static assignment (SSA) form, however '",
            output,
            "' has been used as output names multiple times.");
      }
      lex_ctx.add(output);
    }
  }
  print_warning_if_has_experimental(used_experimental_ops);
}
782 | |
// Utility function to get the imported version of a domain from opset imports.
// Returns -1 if the requested domain is not found in opset_imports.
int get_version_for_domain(const std::string& domain, const std::unordered_map<std::string, int>& opset_imports) {
  const auto found = opset_imports.find(domain);
  return (found != opset_imports.end()) ? found->second : -1;
}
793 | |
794 | void check_opset_compatibility( |
795 | const NodeProto& node, |
796 | const CheckerContext& ctx, |
797 | const std::unordered_map<std::string, int>& func_opset_imports, |
798 | const std::unordered_map<std::string, int>& model_opset_imports) { |
799 | auto func_opset_version = get_version_for_domain(node.domain(), func_opset_imports); |
800 | auto model_opset_version = get_version_for_domain(node.domain(), model_opset_imports); |
801 | |
802 | if (func_opset_version == -1) { |
803 | fail_check("No Opset registered for domain " + node.domain()); |
804 | } |
805 | |
806 | if (model_opset_version == -1) { |
807 | // model does not include opset import for a node present in function body. |
808 | // This is ok as along as the opset import is present in function level opset imports. |
809 | return; |
810 | } |
811 | |
812 | if (func_opset_version == model_opset_version) { |
813 | // both versions are same, no need to verify schema. |
814 | return; |
815 | } |
816 | |
817 | const auto* schema_for_model_import = |
818 | ctx.get_schema_registry()->GetSchema(node.op_type(), model_opset_version, node.domain()); |
819 | |
820 | const auto* schema_for_function_import = |
821 | ctx.get_schema_registry()->GetSchema(node.op_type(), func_opset_version, node.domain()); |
822 | |
823 | if (!schema_for_model_import && !schema_for_function_import) { |
824 | // the op belongs to a custom domain so we cannot verify schema |
825 | return; |
826 | } |
827 | |
828 | // if schema is present for 1 but not other or the schema since versions do not match then raise an error |
829 | if (!schema_for_model_import || !schema_for_function_import || |
830 | schema_for_function_import->since_version() != schema_for_model_import->since_version()) { |
831 | fail_check( |
832 | "Opset import for domain " + node.domain() + " in function op " + node.op_type() + |
833 | "is not compatible with the version imported by model. FunctionOp imports version " + |
834 | ONNX_NAMESPACE::to_string(func_opset_version) + "whereas model imports version " + |
835 | ONNX_NAMESPACE::to_string(model_opset_version)); |
836 | } |
837 | } |
838 | |
839 | void check_model_local_functions( |
840 | const ModelProto& model, |
841 | const CheckerContext& ctx, |
842 | const LexicalScopeContext& parent_lex) { |
843 | // make a copy of model opset imports to maintain a main copy of opset imports across the model and |
844 | // all model local functions to verify opset compatibility |
845 | std::unordered_map<std::string, int> model_opset_imports(ctx.get_opset_imports()); |
846 | |
847 | // merge the opset imports from every function in model_opset_imports |
848 | // only add the opset import if an entry for it does not exist in model_opset_imports |
849 | // if there is an entry then the compatibility will be checked later on in check_opset_compatibility |
850 | // called by check_function. |
851 | for (const auto& function_proto : model.functions()) { |
852 | for (const auto& opset_import : function_proto.opset_import()) { |
853 | if (get_version_for_domain(opset_import.domain(), model_opset_imports) == -1) { |
854 | model_opset_imports[opset_import.domain()] = opset_import.version(); |
855 | } |
856 | } |
857 | } |
858 | |
859 | CheckerContext ctx_copy = ctx; |
860 | ctx_copy.set_opset_imports(model_opset_imports); |
861 | |
862 | for (const auto& function_proto : model.functions()) { |
863 | check_function(function_proto, ctx_copy, parent_lex); |
864 | } |
865 | } |
866 | |
// Validates a model-local function: required fields, SSA form for inputs and
// node outputs, uniqueness of declared outputs and attributes, topological
// ordering of nodes, and per-node opset-import compatibility between the
// function and the model.
void check_function(const FunctionProto& function, const CheckerContext& ctx, const LexicalScopeContext& parent_lex) {
  enforce_non_empty_field(function, name);

  // The 'domain' field is only enforced from IR version 8 onwards, when
  // model-local functions were introduced.
  if (ctx.get_ir_version() >= 0x00000008) {
    enforce_has_field(function, domain);
  }

  const auto& model_opset_imports = ctx.get_opset_imports();
  // Node checking below must resolve schemas against the function's own opset
  // imports, so install them on a copy of the context.
  CheckerContext ctx_copy = ctx;

  std::unordered_map<std::string, int> func_opset_imports;
  for (auto& relied_opset : function.opset_import()) {
    func_opset_imports[relied_opset.domain()] = static_cast<int>(relied_opset.version());
  }

  ctx_copy.set_opset_imports(func_opset_imports);

  LexicalScopeContext lex_ctx{parent_lex};

  // Register function inputs; duplicates within this scope violate SSA.
  for (const auto& input : function.input()) {
    // TODO: If shadowing isn't allowed, this should maybe use
    // this_or_ancestor_graph_has
    if (lex_ctx.this_graph_has(input)) {
      fail_check(
          "Graph must be in single static assignment (SSA) form, however '" , input, "' has been used multiple times." );
    }
    lex_ctx.add(input);
  }

  // Declared function outputs must be distinct names.
  std::unordered_set<std::string> outputs;
  for (const auto& output : function.output()) {
    auto result = outputs.insert(output);
    if (!result.second) {
      fail_check("function (" , function.name(), ") should not have duplicate outputs specified." );
    }
  }

  // Declared attribute names must be distinct as well.
  std::unordered_set<std::string> attrs;
  for (const auto& attr : function.attribute()) {
    auto result = attrs.insert(attr);
    if (!result.second) {
      fail_check("function (" , function.name(), ") should not have duplicate attributes specified." );
    }
  }

  // Collect deprecated experimental ops so one warning is printed at the end.
  std::unordered_set<std::string> used_experimental_ops;
  for (const auto& node : function.node()) {
    // nodes must be in topologically sorted order
    for (const auto& input : node.input()) {
      // explicit optional input
      if (input.empty()) {
        continue;
      }
      // Every non-empty input must already be in scope: either a function
      // input or the output of an earlier node.
      if (!lex_ctx.this_graph_has(input)) {
        fail_check(
            "Nodes in a function must be topologically sorted, however input '" ,
            input,
            "' of node: \n" ,
            "Name: " ,
            node.name(),
            " OpType: " ,
            node.op_type(),
            "\n is neither output of any previous nodes nor input of the function." );
      }
    }

    // check whether the opset version imported for a domain by function and model are
    // compatible
    check_opset_compatibility(node, ctx_copy, func_opset_imports, model_opset_imports);
    if (check_is_experimental_op(node)) {
      used_experimental_ops.insert(node.op_type());
    }
    check_node(node, ctx_copy, lex_ctx);

    // check for SSA form
    for (const auto& output : node.output()) {
      // optional output
      if (output.empty()) {
        continue;
      }
      if (lex_ctx.this_or_ancestor_graph_has(output)) {
        fail_check(
            "Function must be in single static assignment (SSA) form, however '" ,
            output,
            "' has been used as output names multiple times." );
      }
      lex_ctx.add(output);
    }
  }
  print_warning_if_has_experimental(used_experimental_ops);
}
957 | |
958 | void check_model(const ModelProto& model, CheckerContext& ctx) { |
959 | if (!model.ir_version()) { |
960 | fail_check("The model does not have an ir_version set properly." ); |
961 | } |
962 | if (model.ir_version() > IR_VERSION) { |
963 | fail_check("Your model ir_version is higher than the checker's." ); |
964 | } |
965 | if (model.metadata_props_size() > 1) { |
966 | std::unordered_set<std::string> keys; |
967 | for (const StringStringEntryProto& entry : model.metadata_props()) { |
968 | auto i = keys.insert(entry.key()); |
969 | if (!i.second) { |
970 | fail_check("Your model has duplicate keys in metadata_props." ); |
971 | } |
972 | } |
973 | } |
974 | std::unordered_map<std::string, int> versions; |
975 | ctx.set_ir_version(static_cast<int>(model.ir_version())); |
976 | std::unordered_map<std::string, int> opset_imports; |
977 | for (const auto& opset_import : model.opset_import()) { |
978 | opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version()); |
979 | } |
980 | if (model.ir_version() >= 3) { |
981 | if (opset_imports.empty()) { |
982 | fail_check("model with IR version >= 3 must specify opset_import for ONNX" ); |
983 | } |
984 | } else { |
985 | if (opset_imports.empty()) |
986 | opset_imports[ONNX_DOMAIN] = 1; |
987 | else { |
988 | fail_check("model with IR version < 3 cannot have opset_import specified" ); |
989 | } |
990 | } |
991 | ctx.set_opset_imports(opset_imports); |
992 | LexicalScopeContext lex_ctx; |
993 | check_graph(model.graph(), ctx, lex_ctx); |
994 | |
995 | if (ctx.get_ir_version() >= 0x00000008) { |
996 | check_model_local_functions(model, ctx, lex_ctx); |
997 | } |
998 | } |
999 | |
1000 | void check_model(const std::string& model_path, bool full_check) { |
1001 | ModelProto model; |
1002 | LoadProtoFromPath(model_path, model); |
1003 | |
1004 | CheckerContext ctx; |
1005 | std::string model_dir; |
1006 | size_t pos = model_path.find_last_of("\\/" ); |
1007 | if (pos != std::string::npos) { |
1008 | model_dir = model_path.substr(0, pos + 1); |
1009 | } |
1010 | ctx.set_model_dir(model_dir); |
1011 | check_model(model, ctx); |
1012 | |
1013 | if (full_check) { |
1014 | ShapeInferenceOptions options{true, 1, false}; |
1015 | ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options); |
1016 | } |
1017 | } |
1018 | |
1019 | void check_model(const ModelProto& model, bool full_check) { |
1020 | CheckerContext ctx; |
1021 | check_model(model, ctx); |
1022 | if (full_check) { |
1023 | ShapeInferenceOptions options{true, 1, false}; |
1024 | // Do not update the model in place by the check from shape inference |
1025 | // because checker should not modify the original model |
1026 | ModelProto copy = model; |
1027 | ONNX_NAMESPACE::shape_inference::InferShapes(copy, ctx.get_schema_registry(), options); |
1028 | } |
1029 | } |
1030 | |
// Names of retired "experimental" ops in the default ONNX domain. Nodes using
// any of these are collected during graph/function checking and reported via
// print_warning_if_has_experimental (see check_is_experimental_op below).
std::set<std::string> experimental_ops = {
    "ATen" ,
    "Affine" ,
    "ConstantFill" ,
    "Crop" ,
    "DynamicSlice" ,
    "GRUUnit" ,
    "GivenTensorFill" ,
    "ImageScaler" ,
    "ParametricSoftplus" ,
    "Scale" ,
    "ScaledTanh" };
1043 | |
1044 | bool check_is_experimental_op(const NodeProto& node) { |
1045 | return (node.domain() == ONNX_DOMAIN || node.domain() == "ai.onnx" ) && experimental_ops.count(node.op_type()); |
1046 | } |
1047 | |
1048 | #undef fail_check |
1049 | #undef enforce_has_field |
1050 | #undef enforce_has_repeated_field |
1051 | #undef enforce_non_empty_field |
1052 | |
1053 | } // namespace checker |
1054 | } // namespace ONNX_NAMESPACE |
1055 | |