1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/checker.h"
6#include "onnx/common/file_utils.h"
7#include "onnx/common/path.h"
8#include "onnx/defs/schema.h"
9#include "onnx/defs/tensor_proto_util.h"
10#include "onnx/proto_utils.h"
11#include "onnx/shape_inference/implementation.h"
12#include "onnx/string_utils.h"
13
14#include <fstream>
15#include <iterator>
16#include <unordered_set>
17
18#ifdef _WIN32
19#include <direct.h>
20#include <filesystem>
21
22#else // POSIX
23#include <sys/stat.h>
24#endif
25
26namespace ONNX_NAMESPACE {
27namespace checker {
28
// Fails the check (via fail_check) when the singular field `field` is not set
// on `proto`. Uses stringification so the diagnostic names both the field and
// the proto variable at the call site.
#define enforce_has_field(proto, field) \
  do { \
    if (!proto.has_##field()) { \
      fail_check("Field '", #field, "' of '", #proto, "' is required but missing."); \
    } \
  } while (0)

// Fails the check when the repeated field `field` on `proto` has no entries.
#define enforce_has_repeated_field(proto, field) \
  do { \
    if (!proto.field##_size()) { \
      fail_check("Repeated Field '", #field, "' of '", #proto, "' is required but missing."); \
    } \
  } while (0)

// Fails the check when `proto.field()` is an empty string/container.
#define enforce_non_empty_field(proto, field) \
  do { \
    if (proto.field().empty()) { \
      fail_check("Field '", #field, "' of '", #proto, "' is required to be non-empty."); \
    } \
  } while (0)
49
// Checks that a ValueInfoProto is well-formed: the name must be non-empty
// and, for the main graph only, the type must be present and complete for
// the particular TypeProto variant it carries. Violations raise via
// fail_check.
void check_value_info(const ValueInfoProto& value_info, const CheckerContext& ctx) {
  enforce_non_empty_field(value_info, name);
  // Relax constraint for subgraph input/output.
  if (!ctx.is_main_graph())
    return;
  enforce_has_field(value_info, type);
  const auto value_case = value_info.type().value_case();
  switch (value_case) {
    case TypeProto::kTensorType: {
      const auto& type = value_info.type().tensor_type();
      enforce_has_field(type, elem_type);
      enforce_has_field(type, shape);
    } break;
    case TypeProto::kOptionalType: {
      const auto& type = value_info.type().optional_type();
      enforce_has_field(type, elem_type);
    } break;
    case TypeProto::kSequenceType: {
      const auto& type = value_info.type().sequence_type();
      enforce_has_field(type, elem_type);
    } break;
    case TypeProto::kMapType: {
      const auto& type = value_info.type().map_type();
      enforce_has_field(type, key_type);
      enforce_has_field(type, value_type);
    } break;
#ifdef ONNX_ML
    // Opaque types carry no further structure to validate.
    case TypeProto::kOpaqueType:
      break;
#endif
    case TypeProto::kSparseTensorType: {
      const auto& type = value_info.type().sparse_tensor_type();
      enforce_has_field(type, elem_type);
      enforce_has_field(type, shape);
    } break;

    default:
      fail_check("Unrecognized type value case (value_info name: ", value_info.name(), "): ", value_case);
  }
}
90
// Checks that a TensorProto is well-formed: a defined data_type, exactly one
// populated value field matching that data_type (unless the tensor has zero
// elements), and — for externally stored tensors — a safe, existing,
// relative "location" inside the model directory.
void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) {
  enforce_has_field(tensor, data_type);
  if (tensor.data_type() == TensorProto::UNDEFINED) {
    fail_check("setting data_type field (tensor name: ", tensor.name(), ") to UNDEFINED is not allowed");
  }

  // Count how many of the mutually-exclusive value fields are populated;
  // remember the last populated one for error reporting.
  int num_value_fields = 0;

  const char* value_field = nullptr;

// Declares has_<field> (also referenced further below, e.g. has_raw_data)
// and tallies populated value fields.
#define check_data_field(field) \
  bool has_##field = tensor.field().size(); \
  if (has_##field) { \
    ++num_value_fields; \
    value_field = #field; \
  }

  check_data_field(float_data);
  check_data_field(int32_data);
  check_data_field(string_data);
  check_data_field(int64_data);
  check_data_field(raw_data);
  check_data_field(double_data);
  check_data_field(uint64_data);

#undef check_data_field

  bool stored_externally = tensor.has_data_location() && tensor.data_location() == TensorProto::EXTERNAL;
  if (stored_externally) {
    // Externally stored tensors must carry no inline data and must provide a
    // "location" entry that is a relative path staying inside the model dir.
    if (num_value_fields != 0) {
      fail_check(
          "Data of TensorProto ( tensor name: ",
          tensor.name(),
          ") is stored externally and should not have data field.",
          value_field);
    }

    bool has_location = false;
    for (const StringStringEntryProto& entry : tensor.external_data()) {
      if (entry.has_key() && entry.has_value() && entry.key() == "location") {
        has_location = true;
#ifdef _WIN32
        auto file_path = std::filesystem::path(utf8str_to_wstring(entry.value()));
        if (file_path.is_absolute()) {
          fail_check(
              "Location of external TensorProto ( tensor name: ",
              tensor.name(),
              ") should be a relative path, but it is an absolute path: ",
              entry.value());
        }
        auto relative_path = file_path.lexically_normal().make_preferred().wstring();
        // Check that normalized relative path contains ".." on Windows.
        if (relative_path.find(L"..", 0) != std::string::npos) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be file inside the ",
              ctx.get_model_dir(),
              ", but the '",
              entry.value(),
              "' points outside the directory");
        }
        std::wstring data_path = path_join(utf8str_to_wstring(ctx.get_model_dir()), relative_path);
        struct _stat buff;
        if (_wstat(data_path.c_str(), &buff) != 0) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              entry.value(),
              ", but it doesn't exist or is not accessible.");
        }
#else // POSIX
        if (entry.value().empty()) {
          fail_check("Location of external TensorProto ( tensor name: ", tensor.name(), ") should not be empty.");
        } else if (entry.value()[0] == '/') {
          fail_check(
              "Location of external TensorProto ( tensor name: ",
              tensor.name(),
              ") should be a relative path, but it is an absolute path: ",
              entry.value());
        }
        std::string relative_path = clean_relative_path(entry.value());
        // Check that normalized relative path contains ".." on POSIX
        if (relative_path.find("..", 0) != std::string::npos) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be file inside the ",
              ctx.get_model_dir(),
              ", but the '",
              entry.value(),
              "' points outside the directory");
        }
        std::string data_path = path_join(ctx.get_model_dir(), relative_path);
        // use stat to check whether the file exists
        struct stat buffer;
        if (stat((data_path).c_str(), &buffer) != 0) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              data_path,
              ", but it doesn't exist or is not accessible.");
        }
        // Do not allow symlinks or directories.
        if (!S_ISREG(buffer.st_mode)) {
          fail_check(
              "Data of TensorProto ( tensor name: ",
              tensor.name(),
              ") should be stored in ",
              data_path,
              ", but it is not regular file.");
        }
#endif
      }
    }
    if (!has_location) {
      fail_check("TensorProto ( tensor name: ", tensor.name(), ") is stored externally but doesn't have a location.");
    }
    return;
  }
  // Inline data: the element count implied by dims decides whether exactly
  // one value field must be populated (nelem != 0) or none at all (nelem == 0).
  int64_t nelem = 1;
  for (auto x : tensor.dims()) {
    nelem *= x;
  }
  if (nelem == 0 && num_value_fields != 0) {
    fail_check("TensorProto (tensor name: ", tensor.name(), ") is 0-element but contains data!");
  }
  if (nelem != 0 && num_value_fields != 1) {
    fail_check("TensorProto (tensor name: ", tensor.name(), ") should contain one and only one value field.");
  }
  if (has_raw_data) {
    // raw_data may hold any type except STRING; no per-type field check applies.
    if (tensor.data_type() == TensorProto::STRING) {
      fail_check("STRING data (tensor name: ", tensor.name(), ") should not be stored in raw_data field");
    }
    return;
  } else {
// Fails if the populated value field is not the one this data_type requires.
#define check_field(field) \
  if (nelem != 0 && !has_##field) { \
    fail_check( \
        "values of data_type '", \
        tensor.data_type(), \
        "' should be stored in field '", \
        #field, \
        "' instead of '", \
        value_field, \
        "'"); \
  }

    switch (tensor.data_type()) {
      case TensorProto::FLOAT:
      case TensorProto::COMPLEX64:
        check_field(float_data);
        break;

      case TensorProto::DOUBLE:
      case TensorProto::COMPLEX128:
        check_field(double_data);
        break;

      case TensorProto::INT32:
      case TensorProto::UINT8:
      case TensorProto::INT8:
      case TensorProto::UINT16:
      case TensorProto::INT16:
      case TensorProto::BOOL:
      case TensorProto::FLOAT16:
      case TensorProto::BFLOAT16:
        // Narrow integer and half-float types are all packed into int32_data.
        check_field(int32_data);
        break;

      case TensorProto::INT64:
        check_field(int64_data);
        break;

      case TensorProto::UINT32:
      case TensorProto::UINT64:
        check_field(uint64_data);
        break;

      case TensorProto::STRING:
        check_field(string_data);
        break;

      default:
        fail_check("Unrecognized data_type (tensor name: ", tensor.name(), "): ", tensor.data_type());
    }
  }

#undef check_field
}
283
284void check_sequence(const SequenceProto& sequence, const CheckerContext& ctx) {
285 enforce_has_field(sequence, elem_type);
286 if (sequence.elem_type() == SequenceProto::TENSOR) {
287 for (const TensorProto& tensor : sequence.tensor_values()) {
288 check_tensor(tensor, ctx);
289 }
290 } else if (sequence.elem_type() == SequenceProto::SPARSE_TENSOR) {
291 for (const SparseTensorProto& sparse_tensor : sequence.sparse_tensor_values()) {
292 check_sparse_tensor(sparse_tensor, ctx);
293 }
294 } else if (sequence.elem_type() == SequenceProto::SEQUENCE) {
295 for (const SequenceProto& seq : sequence.sequence_values()) {
296 check_sequence(seq, ctx);
297 }
298 } else if (sequence.elem_type() == SequenceProto::MAP) {
299 for (const MapProto& map : sequence.map_values()) {
300 check_map(map, ctx);
301 }
302 } else {
303 fail_check(
304 "Sequence ( Structure name: ",
305 sequence.name(),
306 ", elem_type: ",
307 sequence.elem_type(),
308 ") is not have a valid element type.");
309 }
310}
311
312void check_optional(const OptionalProto& optional, const CheckerContext& ctx) {
313 enforce_has_field(optional, elem_type);
314 if (optional.elem_type() == OptionalProto::UNDEFINED) {
315 return;
316 } else if (optional.elem_type() == OptionalProto::TENSOR) {
317 if (optional.has_tensor_value())
318 check_tensor(optional.tensor_value(), ctx);
319 } else if (optional.elem_type() == OptionalProto::SPARSE_TENSOR) {
320 if (optional.has_sparse_tensor_value())
321 check_sparse_tensor(optional.sparse_tensor_value(), ctx);
322 } else if (optional.elem_type() == OptionalProto::SEQUENCE) {
323 if (optional.has_sequence_value())
324 check_sequence(optional.sequence_value(), ctx);
325 } else if (optional.elem_type() == OptionalProto::MAP) {
326 if (optional.has_map_value())
327 check_map(optional.map_value(), ctx);
328 } else {
329 fail_check(
330 "Optional ( Structure name: ",
331 optional.name(),
332 ", elem_type: ",
333 optional.elem_type(),
334 ") is not have a valid element type.");
335 }
336}
337
338void check_map(const MapProto& map, const CheckerContext& ctx) {
339 enforce_has_field(map, key_type);
340 if (map.key_type() == TensorProto::UNDEFINED) {
341 fail_check("setting key_type field (map name: ", map.name(), ") to UNDEFINED is not allowed");
342 }
343 // Check if key is a valid type, specifically INT8, INT16, INT32, INT64,
344 // UINT8, UINT16, UINT32, UINT64, or STRING.
345 if ((map.key_type() == TensorProto::FLOAT) || (map.key_type() == TensorProto::BOOL) ||
346 (map.key_type() == TensorProto::FLOAT16) || (map.key_type() == TensorProto::COMPLEX64) ||
347 (map.key_type() == TensorProto::COMPLEX128)) {
348 fail_check(
349 "setting key_type field (map name: ",
350 map.name(),
351 ") to invalid TensorProto key_type ",
352 map.key_type(),
353 " is not allowed");
354 }
355
356 // MapProto will use either keys or string_keys, so only one should be > 0.
357 if ((map.keys_size() > 0) && (map.string_keys_size() > 0)) {
358 fail_check("Map (name: ", map.name(), ") should not contain more than one keys field.");
359 }
360
361 int num_keys = map.keys_size() + map.string_keys_size();
362 int num_values = 0;
363
364 enforce_has_field(map, values);
365 check_sequence(map.values(), ctx);
366
367 if (map.values().elem_type() == SequenceProto::TENSOR) {
368 num_values = map.values().tensor_values_size();
369 } else if (map.values().elem_type() == SequenceProto::SPARSE_TENSOR) {
370 num_values = map.values().sparse_tensor_values_size();
371 } else if (map.values().elem_type() == SequenceProto::SEQUENCE) {
372 num_values = map.values().sequence_values_size();
373 } else if (map.values().elem_type() == SequenceProto::MAP) {
374 num_values = map.values().map_values_size();
375 }
376
377 if (num_keys != num_values) {
378 fail_check("Length of map keys and map values are not the same (map name: ", map.name(), ")");
379 }
380}
381
382// Check that the index data stored in a SparseTensorProto is valid.
383// indices: a 1-dimensional tensor; indices[i] represents the
384// linearized index value for the i-th nonzero value.
385void check_sparse_tensor_indices_1(
386 const TensorProto& indices,
387 const SparseTensorProto& sparse_tensor_proto,
388 size_t nnz) {
389 int dense_rank = sparse_tensor_proto.dims_size();
390 int64_t dense_size = 1;
391 for (int i = 0; i < dense_rank; ++i)
392 dense_size *= sparse_tensor_proto.dims(i);
393 if (static_cast<size_t>(indices.dims(0)) != nnz) {
394 fail_check("Sparse tensor indices (", indices.name(), ") has ", indices.dims(0), " values, but NNZ is ", nnz);
395 }
396
397 // Check if indices appear in ascending order, and if they have valid
398 // values. The i-th value in index_data is the linear index of the i-th
399 // non-zero value.
400 const std::vector<int64_t> index_data = ParseData<int64_t>(&indices);
401
402 int64_t prev_index = -1;
403 for (size_t i = 0; i < nnz; ++i) {
404 int64_t curr_index = index_data[i]; // linearized index of i-th value
405 if (curr_index < 0 || curr_index >= dense_size) {
406 fail_check(
407 "Sparse tensor (",
408 indices.name(),
409 ") index value at position [",
410 i,
411 "] out of range [0, ",
412 dense_size - 1,
413 "]");
414 }
415 if (curr_index <= prev_index) {
416 fail_check("Sparse tensor (", indices.name(), ") index value at position [", i, "] not in sorted order.");
417 }
418 prev_index = curr_index;
419 }
420}
421
422// Check that the index data stored in a SparseTensorProto is valid.
423// indices: a 2-dimensional tensor; indices[i,j] represents the j-th
424// index value for the i-th nonzero value.
425void check_sparse_tensor_indices_2(
426 const TensorProto& indices,
427 const SparseTensorProto& sparse_tensor_proto,
428 size_t nnz) {
429 int dense_rank = sparse_tensor_proto.dims_size();
430 if (static_cast<size_t>(indices.dims(0)) != nnz) {
431 fail_check("Sparse tensor indices (", indices.name(), ") first dimension size does not equal NNZ.");
432 }
433 if (indices.dims(1) != dense_rank) {
434 fail_check("Sparse tensor indices (", indices.name(), ") second dimension size does not match rank of tensor.");
435 }
436
437 // Check if indices appear in ascending order, and if they have valid
438 // values.
439 const std::vector<int64_t> index_data = ParseData<int64_t>(&indices);
440 int64_t prev_index = -1;
441 for (size_t i = 0; i < nnz; ++i) {
442 int64_t curr_index = 0; // linearized index of i-th value
443 for (int j = 0; j < dense_rank; ++j) {
444 auto index_ij = index_data[i * dense_rank + j];
445 if ((index_ij < 0) || (index_ij >= sparse_tensor_proto.dims(j))) {
446 fail_check("Sparse tensor (", indices.name(), ") index value at position [", i, ",", j, "] out of range.");
447 }
448 curr_index = curr_index * sparse_tensor_proto.dims(j) + index_ij;
449 }
450 if (curr_index <= prev_index) {
451 fail_check(
452 "Sparse tensor (", indices.name(), ") index value at position [", i, "] not in lexicographic sorted order.");
453 }
454 prev_index = curr_index;
455 }
456}
457
// Checks a SparseTensorProto: `values` must be a rank-1 tensor of NNZ
// elements, `dims` must describe a dense shape with rank > 0 and positive
// dimensions, and `indices` (INT64, rank 1 or 2) must be valid for NNZ.
void check_sparse_tensor(const SparseTensorProto& sparse_tensor_proto, const CheckerContext& ctx) {
  enforce_has_field(sparse_tensor_proto, values);

  const TensorProto& values = sparse_tensor_proto.values();
  check_tensor(values, ctx);

  // values must be a tensor of shape [NNZ]
  // Currently we restrict the value associated with a particular index-tuple
  // to be a single value. In the future, if there is a requirement,
  // we may extend this to permit the value to be a "sub-tensor", in which
  // case values will have dimension > 1.
  if (values.dims_size() != 1) {
    fail_check("Sparse tensor values (", values.name(), ") must have rank 1.");
  }
  size_t nnz = static_cast<size_t>(values.dims(0));
  int dense_rank = sparse_tensor_proto.dims_size();
  if (dense_rank == 0) {
    fail_check("Sparse tensor (", values.name(), ") must have a dense-rank > 0");
  }
  for (int i = 0; i < dense_rank; ++i) {
    if (sparse_tensor_proto.dims(i) <= 0) {
      fail_check("Sparse tensor (", values.name(), ") dimensions are not positive.");
    }
  }

  if (sparse_tensor_proto.has_indices()) {
    const TensorProto& indices = sparse_tensor_proto.indices();
    check_tensor(indices, ctx);
    if (indices.data_type() != TensorProto::INT64) {
      fail_check("Sparse tensor indices (", indices.name(), ") must have INT64 type.");
    }
    switch (indices.dims().size()) {
      case 1:
        // Indices in linearized format
        check_sparse_tensor_indices_1(indices, sparse_tensor_proto, nnz);
        return;
      case 2:
        // Check COO-style index. E.g., an index for a 3D tensor is a 3-tuple.
        check_sparse_tensor_indices_2(indices, sparse_tensor_proto, nnz);
        return;
      default:
        fail_check("Sparse tensor indices (", indices.name(), ") must have rank 1 or 2.");
    }
  } else if (nnz != 0) {
    // Missing indices is only acceptable for an all-zero (NNZ == 0) tensor.
    fail_check("Sparse tensor (", values.name(), ") has no index values.");
  }
}
505
// NB: This is a generic "attribute well-formedness" check, it doesn't
// actually test if an attribute is valid per a schema.
// Verifies that at most one value field is populated, that it matches the
// declared `type` when present, and recursively checks tensor/sparse-tensor/
// graph payloads.
void check_attribute(const AttributeProto& attr, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx) {
  enforce_non_empty_field(attr, name);

  // The `type` field is only mandatory from IR version 2 on.
  if (ctx.get_ir_version() >= 0x00000002) {
    enforce_has_field(attr, type);
  }

  int used_fields = 0;

// Fails when the declared attribute type disagrees with the value field set.
#define check_type(expected_type) \
  if (attr.has_type() && attr.type() != expected_type) { \
    fail_check("type field and data field mismatch in attribute ", attr.name(), "."); \
  }

// Counts a populated singular value field and checks its declared type.
#define check_singular_field(field, type) \
  if (attr.has_##field()) { \
    ++used_fields; \
    check_type(type); \
  }

// Counts a non-empty repeated value field and checks its declared type.
#define check_repeated_field(field, type) \
  if (attr.field##_size() > 0) { \
    ++used_fields; \
    check_type(type); \
  }

  check_singular_field(f, AttributeProto::FLOAT);
  check_singular_field(i, AttributeProto::INT);
  check_singular_field(s, AttributeProto::STRING);
  check_singular_field(t, AttributeProto::TENSOR);
  check_singular_field(g, AttributeProto::GRAPH);
  check_singular_field(tp, AttributeProto::TYPE_PROTO);
  check_singular_field(sparse_tensor, AttributeProto::SPARSE_TENSOR);
  check_repeated_field(floats, AttributeProto::FLOATS);
  check_repeated_field(ints, AttributeProto::INTS);
  check_repeated_field(strings, AttributeProto::STRINGS);
  check_repeated_field(tensors, AttributeProto::TENSORS);
  check_repeated_field(graphs, AttributeProto::GRAPHS);
  check_repeated_field(sparse_tensors, AttributeProto::SPARSE_TENSORS);
  check_repeated_field(type_protos, AttributeProto::TYPE_PROTOS);

#undef check_type
#undef check_singular_field
#undef check_repeated_field

  // Normally, used_fields is expected to be 1.
  // In proto3, when the value to be set is type default value (say 0 for
  // int), used_fields may be 0.
  if (used_fields > 1) {
    fail_check("Attribute (name: ", attr.name(), ") should not contain more than one value field.");
  }

  if (!ctx.is_main_graph()) {
    // It's an attribute of a node in function body.
    if (attr.has_ref_attr_name() && used_fields != 0) {
      // The attribute proto is supposed to refer to data outside and does not
      // have its own value field set.
      fail_check("Attribute (name: ", attr.name(), ") should refer to attribute in parent node.");
    }
  }

  if (attr.has_t()) {
    check_tensor(attr.t(), ctx);
  }

  if (attr.has_sparse_tensor()) {
    check_sparse_tensor(attr.sparse_tensor(), ctx);
  }

  // Graph-valued attributes are checked as subgraphs (is_main_graph = false
  // relaxes some constraints, e.g. value_info type requirements).
  if (attr.has_g()) {
    CheckerContext subgraph_ctx(ctx);
    subgraph_ctx.set_is_main_graph(false);
    check_graph(attr.g(), subgraph_ctx, lex_ctx);
  }

  for (const auto& tensor : attr.tensors()) {
    check_tensor(tensor, ctx);
  }
  for (const auto& sparse_tensor : attr.sparse_tensors()) {
    check_sparse_tensor(sparse_tensor, ctx);
  }
  if (attr.graphs().size() > 0) {
    CheckerContext subgraph_ctx(ctx);
    subgraph_ctx.set_is_main_graph(false);
    for (const auto& graph : attr.graphs()) {
      check_graph(graph, subgraph_ctx, lex_ctx);
    }
  }
}
597
// Emits a single stdout warning listing every experimental op in the given
// set; prints nothing when the set is empty.
void print_warning_if_has_experimental(const std::unordered_set<std::string>& used_experimental_ops) {
  if (used_experimental_ops.empty()) {
    return;
  }
  std::string op_list;
  for (const auto& op_name : used_experimental_ops) {
    op_list.append(" ").append(op_name).append(",");
  }
  // Drop the trailing comma left by the loop above.
  op_list.pop_back();
  std::cout << "Warning: Model contains experimental ops:" + op_list << std::endl;
}
609
// Checks a single node: non-empty op_type, at least one input or output, an
// opset import covering the node's domain, well-formed attributes, and —
// unless the op is experimental or lives in an unregistered custom domain —
// verification against its registered OpSchema.
void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalScopeContext& lex_ctx) {
  enforce_non_empty_field(node, op_type);

  if (node.input().empty() && node.output().empty()) {
    fail_check("NodeProto (name: ", node.name(), ", type: ", node.op_type(), ") has zero input and zero output.");
  }

  // Resolve domain for node
  const auto& opset_imports = ctx.get_opset_imports();
  auto dit = opset_imports.find(node.domain());
  if (dit == opset_imports.end()) {
    fail_check("No opset import for domain '" + node.domain() + "'");
  }
  auto domain_version = dit->second;

  for (const auto& attr : node.attribute()) {
    check_attribute(attr, ctx, lex_ctx);
  }

  // This issue will be caught by check_graph instead
  if (check_is_experimental_op(node)) {
    return;
  }

  const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain());
  if (!schema) {
    if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" ||
        node.domain() == AI_ONNX_TRAINING_DOMAIN) {
      // fail the checker if op in built-in domains has no schema
      fail_check(
          "No Op registered for " + node.op_type() + " with domain_version of " +
          ONNX_NAMESPACE::to_string(domain_version));
    } else {
      // TODO: expose the registration of the op schemas appropriately in
      // python, so we can load and register operators in other domains
      //
      // before we complete the above todo, let's skip the schema check for
      // now
    }
  } else if (schema->Deprecated()) {
    fail_check(
        "Op registered for " + node.op_type() + " is deprecated in domain_version of " +
        ONNX_NAMESPACE::to_string(domain_version));
  } else {
    // Full schema-level validation (inputs/outputs/attributes per spec).
    schema->Verify(node);
  }
}
657
// Checks a graph: validates inputs/outputs/initializers, enforces unique
// initializer names, verifies nodes are topologically sorted with respect to
// the lexical scope (including names inherited from ancestor graphs), checks
// each node, and enforces single-static-assignment (SSA) on node outputs.
void check_graph(const GraphProto& graph, const CheckerContext& ctx, const LexicalScopeContext& parent_lex) {
  enforce_non_empty_field(graph, name);

  for (const auto& value_info : graph.input()) {
    check_value_info(value_info, ctx);
  }
  for (const auto& value_info : graph.output()) {
    check_value_info(value_info, ctx);
  }

  // Inherit values available in outer scope
  // Note that we do not allow shadowing, so the presence of an already-defined
  // name is always an error.
  LexicalScopeContext lex_ctx{parent_lex};

  for (const auto& value_info : graph.input()) {
    // TODO: If shadowing isn't allowed, this should maybe use
    // this_or_ancestor_graph_has
    if (lex_ctx.this_graph_has(value_info.name())) {
      fail_check(
          "Graph must be in single static assignment (SSA) form, however '",
          value_info.name(),
          "' has been used as graph input names multiple times.");
    }
    lex_ctx.add(value_info.name());
  }

  // Tracks initializer names (dense and sparse together) by reference to
  // avoid copying; uniqueness is required across both repeated fields.
  std::unordered_set<std::reference_wrapper<const std::string>, std::hash<std::string>, std::equal_to<std::string>>
      initializer_name_checker;

  for (const auto& init : graph.initializer()) {
    enforce_has_field(init, name);
    const auto& name = init.name();
    if (name.empty()) {
      fail_check("Tensor initializers must have a non-empty name");
    }

    if (!initializer_name_checker.insert(std::cref(name)).second) {
      fail_check(name + " initializer name is not unique");
    }

    check_tensor(init, ctx);

    if (ctx.get_ir_version() <= 0x00000003) {
      // Initializers are a subset of graph inputs for IR_VERSION <= 3
      if (!lex_ctx.this_graph_has(name)) {
        fail_check(name + " in initializer but not in graph input");
      }
    } else {
      // An initializer is allowed to have the same name as an input,
      // but is not required to (for IR_VERSION >= 4)
      lex_ctx.add(name);
    }
  }

  for (const auto& sparse_init : graph.sparse_initializer()) {
    // A sparse initializer is named via its `values` tensor.
    const auto& values = sparse_init.values();
    enforce_has_field(values, name);
    const auto& name = values.name();
    if (name.empty()) {
      fail_check("Sparse tensor initializers must have a non-empty name");
    }
    if (!initializer_name_checker.insert(std::cref(name)).second) {
      fail_check(name + " sparse initializer name is not unique across initializers and sparse_initializers");
    }
    check_sparse_tensor(sparse_init, ctx);
    lex_ctx.add(name);
  }
  std::unordered_set<std::string> used_experimental_ops;
  for (const auto& node : graph.node()) {
    // nodes must be in topologically sorted order
    for (const auto& input : node.input()) {
      // explicit optional input
      if (input.empty()) {
        continue;
      }
      if (!lex_ctx.this_or_ancestor_graph_has(input)) {
        fail_check(
            "Nodes in a graph must be topologically sorted, however input '",
            input,
            "' of node: \n",
            "name: ",
            node.name(),
            " OpType: ",
            node.op_type(),
            "\n is not output of any previous nodes.");
      }
    }

    if (check_is_experimental_op(node)) {
      used_experimental_ops.insert(node.op_type());
    }

    // This needs to happen before SSA check since we don't want to recurse and
    // find that outputs from control flow ops are colliding with names in the
    // inner block

    ONNX_TRY {
      check_node(node, ctx, lex_ctx);
    }
    ONNX_CATCH(ValidationError & ex) {
      // Re-throw with node identification appended for easier debugging.
      ONNX_HANDLE_EXCEPTION([&]() {
        ex.AppendContext("Bad node spec for node. Name: " + node.name() + " OpType: " + node.op_type());
        ONNX_THROW_EX(ex);
      });
    }
    // check for SSA form
    for (const auto& output : node.output()) {
      // optional output
      if (output.empty()) {
        continue;
      }

      if (lex_ctx.this_or_ancestor_graph_has(output)) {
        fail_check(
            "Graph must be in single static assignment (SSA) form, however '",
            output,
            "' has been used as output names multiple times.");
      }
      lex_ctx.add(output);
    }
  }
  print_warning_if_has_experimental(used_experimental_ops);
}
782
// Utility function to get the imported version of a domain from opset imports.
// Returns -1 if the requested domain is not found in opset_imports.
int get_version_for_domain(const std::string& domain, const std::unordered_map<std::string, int>& opset_imports) {
  const auto entry = opset_imports.find(domain);
  return (entry == opset_imports.end()) ? -1 : entry->second;
}
793
794void check_opset_compatibility(
795 const NodeProto& node,
796 const CheckerContext& ctx,
797 const std::unordered_map<std::string, int>& func_opset_imports,
798 const std::unordered_map<std::string, int>& model_opset_imports) {
799 auto func_opset_version = get_version_for_domain(node.domain(), func_opset_imports);
800 auto model_opset_version = get_version_for_domain(node.domain(), model_opset_imports);
801
802 if (func_opset_version == -1) {
803 fail_check("No Opset registered for domain " + node.domain());
804 }
805
806 if (model_opset_version == -1) {
807 // model does not include opset import for a node present in function body.
808 // This is ok as along as the opset import is present in function level opset imports.
809 return;
810 }
811
812 if (func_opset_version == model_opset_version) {
813 // both versions are same, no need to verify schema.
814 return;
815 }
816
817 const auto* schema_for_model_import =
818 ctx.get_schema_registry()->GetSchema(node.op_type(), model_opset_version, node.domain());
819
820 const auto* schema_for_function_import =
821 ctx.get_schema_registry()->GetSchema(node.op_type(), func_opset_version, node.domain());
822
823 if (!schema_for_model_import && !schema_for_function_import) {
824 // the op belongs to a custom domain so we cannot verify schema
825 return;
826 }
827
828 // if schema is present for 1 but not other or the schema since versions do not match then raise an error
829 if (!schema_for_model_import || !schema_for_function_import ||
830 schema_for_function_import->since_version() != schema_for_model_import->since_version()) {
831 fail_check(
832 "Opset import for domain " + node.domain() + " in function op " + node.op_type() +
833 "is not compatible with the version imported by model. FunctionOp imports version " +
834 ONNX_NAMESPACE::to_string(func_opset_version) + "whereas model imports version " +
835 ONNX_NAMESPACE::to_string(model_opset_version));
836 }
837}
838
839void check_model_local_functions(
840 const ModelProto& model,
841 const CheckerContext& ctx,
842 const LexicalScopeContext& parent_lex) {
843 // make a copy of model opset imports to maintain a main copy of opset imports across the model and
844 // all model local functions to verify opset compatibility
845 std::unordered_map<std::string, int> model_opset_imports(ctx.get_opset_imports());
846
847 // merge the opset imports from every function in model_opset_imports
848 // only add the opset import if an entry for it does not exist in model_opset_imports
849 // if there is an entry then the compatibility will be checked later on in check_opset_compatibility
850 // called by check_function.
851 for (const auto& function_proto : model.functions()) {
852 for (const auto& opset_import : function_proto.opset_import()) {
853 if (get_version_for_domain(opset_import.domain(), model_opset_imports) == -1) {
854 model_opset_imports[opset_import.domain()] = opset_import.version();
855 }
856 }
857 }
858
859 CheckerContext ctx_copy = ctx;
860 ctx_copy.set_opset_imports(model_opset_imports);
861
862 for (const auto& function_proto : model.functions()) {
863 check_function(function_proto, ctx_copy, parent_lex);
864 }
865}
866
// Checks a model-local function: non-empty name (and, from IR version 8 on,
// a domain field), unique outputs and attribute names, nodes topologically
// sorted and checked under the function's own opset imports, opset
// compatibility between function and model imports, and SSA form for node
// outputs.
void check_function(const FunctionProto& function, const CheckerContext& ctx, const LexicalScopeContext& parent_lex) {
  enforce_non_empty_field(function, name);

  if (ctx.get_ir_version() >= 0x00000008) {
    enforce_has_field(function, domain);
  }

  const auto& model_opset_imports = ctx.get_opset_imports();
  // Nodes in the function body are checked against the function's own opset
  // imports (set below), not the model's.
  CheckerContext ctx_copy = ctx;

  std::unordered_map<std::string, int> func_opset_imports;
  for (auto& relied_opset : function.opset_import()) {
    func_opset_imports[relied_opset.domain()] = static_cast<int>(relied_opset.version());
  }

  ctx_copy.set_opset_imports(func_opset_imports);

  LexicalScopeContext lex_ctx{parent_lex};

  for (const auto& input : function.input()) {
    // TODO: If shadowing isn't allowed, this should maybe use
    // this_or_ancestor_graph_has
    if (lex_ctx.this_graph_has(input)) {
      fail_check(
          "Graph must be in single static assignment (SSA) form, however '", input, "' has been used multiple times.");
    }
    lex_ctx.add(input);
  }

  // Function outputs must be pairwise distinct.
  std::unordered_set<std::string> outputs;
  for (const auto& output : function.output()) {
    auto result = outputs.insert(output);
    if (!result.second) {
      fail_check("function (", function.name(), ") should not have duplicate outputs specified.");
    }
  }

  // Declared attribute names must be pairwise distinct.
  std::unordered_set<std::string> attrs;
  for (const auto& attr : function.attribute()) {
    auto result = attrs.insert(attr);
    if (!result.second) {
      fail_check("function (", function.name(), ") should not have duplicate attributes specified.");
    }
  }
  std::unordered_set<std::string> used_experimental_ops;
  for (const auto& node : function.node()) {
    // nodes must be in topologically sorted order
    for (const auto& input : node.input()) {
      // explicit optional input
      if (input.empty()) {
        continue;
      }
      if (!lex_ctx.this_graph_has(input)) {
        fail_check(
            "Nodes in a function must be topologically sorted, however input '",
            input,
            "' of node: \n",
            "Name: ",
            node.name(),
            " OpType: ",
            node.op_type(),
            "\n is neither output of any previous nodes nor input of the function.");
      }
    }

    // check whether the opset version imported for a domain by function and model are
    // compatible
    check_opset_compatibility(node, ctx_copy, func_opset_imports, model_opset_imports);
    if (check_is_experimental_op(node)) {
      used_experimental_ops.insert(node.op_type());
    }
    check_node(node, ctx_copy, lex_ctx);

    // check for SSA form
    for (const auto& output : node.output()) {
      // optional output
      if (output.empty()) {
        continue;
      }
      if (lex_ctx.this_or_ancestor_graph_has(output)) {
        fail_check(
            "Function must be in single static assignment (SSA) form, however '",
            output,
            "' has been used as output names multiple times.");
      }
      lex_ctx.add(output);
    }
  }
  print_warning_if_has_experimental(used_experimental_ops);
}
957
958void check_model(const ModelProto& model, CheckerContext& ctx) {
959 if (!model.ir_version()) {
960 fail_check("The model does not have an ir_version set properly.");
961 }
962 if (model.ir_version() > IR_VERSION) {
963 fail_check("Your model ir_version is higher than the checker's.");
964 }
965 if (model.metadata_props_size() > 1) {
966 std::unordered_set<std::string> keys;
967 for (const StringStringEntryProto& entry : model.metadata_props()) {
968 auto i = keys.insert(entry.key());
969 if (!i.second) {
970 fail_check("Your model has duplicate keys in metadata_props.");
971 }
972 }
973 }
974 std::unordered_map<std::string, int> versions;
975 ctx.set_ir_version(static_cast<int>(model.ir_version()));
976 std::unordered_map<std::string, int> opset_imports;
977 for (const auto& opset_import : model.opset_import()) {
978 opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
979 }
980 if (model.ir_version() >= 3) {
981 if (opset_imports.empty()) {
982 fail_check("model with IR version >= 3 must specify opset_import for ONNX");
983 }
984 } else {
985 if (opset_imports.empty())
986 opset_imports[ONNX_DOMAIN] = 1;
987 else {
988 fail_check("model with IR version < 3 cannot have opset_import specified");
989 }
990 }
991 ctx.set_opset_imports(opset_imports);
992 LexicalScopeContext lex_ctx;
993 check_graph(model.graph(), ctx, lex_ctx);
994
995 if (ctx.get_ir_version() >= 0x00000008) {
996 check_model_local_functions(model, ctx, lex_ctx);
997 }
998}
999
1000void check_model(const std::string& model_path, bool full_check) {
1001 ModelProto model;
1002 LoadProtoFromPath(model_path, model);
1003
1004 CheckerContext ctx;
1005 std::string model_dir;
1006 size_t pos = model_path.find_last_of("\\/");
1007 if (pos != std::string::npos) {
1008 model_dir = model_path.substr(0, pos + 1);
1009 }
1010 ctx.set_model_dir(model_dir);
1011 check_model(model, ctx);
1012
1013 if (full_check) {
1014 ShapeInferenceOptions options{true, 1, false};
1015 ONNX_NAMESPACE::shape_inference::InferShapes(model, ctx.get_schema_registry(), options);
1016 }
1017}
1018
1019void check_model(const ModelProto& model, bool full_check) {
1020 CheckerContext ctx;
1021 check_model(model, ctx);
1022 if (full_check) {
1023 ShapeInferenceOptions options{true, 1, false};
1024 // Do not update the model in place by the check from shape inference
1025 // because checker should not modify the original model
1026 ModelProto copy = model;
1027 ONNX_NAMESPACE::shape_inference::InferShapes(copy, ctx.get_schema_registry(), options);
1028 }
1029}
1030
// Legacy experimental op names in the default ONNX domain; nodes using any
// of these trigger the deprecation warning emitted by
// print_warning_if_has_experimental (see check_is_experimental_op below).
std::set<std::string> experimental_ops = {
    "ATen",
    "Affine",
    "ConstantFill",
    "Crop",
    "DynamicSlice",
    "GRUUnit",
    "GivenTensorFill",
    "ImageScaler",
    "ParametricSoftplus",
    "Scale",
    "ScaledTanh"};
1043
1044bool check_is_experimental_op(const NodeProto& node) {
1045 return (node.domain() == ONNX_DOMAIN || node.domain() == "ai.onnx") && experimental_ops.count(node.op_type());
1046}
1047
1048#undef fail_check
1049#undef enforce_has_field
1050#undef enforce_has_repeated_field
1051#undef enforce_non_empty_field
1052
1053} // namespace checker
1054} // namespace ONNX_NAMESPACE
1055