1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #include <stdexcept> |
8 | #include <unordered_map> |
9 | #include <unordered_set> |
10 | #include "onnx/defs/function.h" |
11 | #include "onnx/defs/schema.h" |
12 | #include "onnx/onnx-data_pb.h" |
13 | #include "onnx/onnx-operators_pb.h" |
14 | #include "onnx/onnx_pb.h" |
15 | #include "onnx/string_utils.h" |
16 | |
17 | namespace ONNX_NAMESPACE { |
18 | namespace checker { |
// Exception type used by the ONNX checker to report validation failures.
//
// The base std::runtime_error holds the original message. AppendContext()
// lets callers attach extra context (typically before rethrowing); what()
// then reports the original message followed by every appended context, in
// the order it was appended.
class ValidationError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  // Returns the original message plus all appended context, or just the
  // original message if AppendContext() has never been called.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Appends "\n\n==> Context: <context>" to the message reported by what().
  // Repeated calls accumulate: earlier context is kept rather than replaced,
  // so an error rethrown through several catch sites reports all of them.
  void AppendContext(const std::string& context) {
    if (expanded_message_.empty()) {
      // First call: seed with the original message stored in the base class.
      expanded_message_ = std::runtime_error::what();
    }
    expanded_message_ += "\n\n==> Context: ";
    expanded_message_ += context;
  }

 private:
  // Original message followed by all appended context; empty until the first
  // AppendContext() call.
  std::string expanded_message_;
};
35 | |
// Throws an ONNX_NAMESPACE::checker::ValidationError whose message is built by
// ONNX_NAMESPACE::MakeString(__VA_ARGS__).
// NOTE(review): ONNX_THROW_EX is defined elsewhere (presumably a plain `throw`,
// or an abort when exceptions are disabled) -- confirm.
// NOTE(review): the expansion already ends in ';'. A call written as
// `fail_check(...);` therefore leaves an extra empty statement, and an
// unbraced `if (c) fail_check(...); else ...` does not compile -- brace such
// bodies. The trailing ';' is kept because existing call sites may omit their
// own semicolon.
#define fail_check(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::checker::ValidationError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
38 | |
39 | class CheckerContext final { |
40 | public: |
41 | int get_ir_version() const { |
42 | return ir_version_; |
43 | } |
44 | void set_ir_version(int v) { |
45 | ir_version_ = v; |
46 | } |
47 | const std::unordered_map<std::string, int>& get_opset_imports() const { |
48 | return opset_imports_; |
49 | } |
50 | void set_opset_imports(std::unordered_map<std::string, int> imps) { |
51 | opset_imports_ = std::move(imps); |
52 | } |
53 | bool is_main_graph() const { |
54 | return is_main_graph_; |
55 | } |
56 | void set_is_main_graph(bool is_main_graph) { |
57 | is_main_graph_ = is_main_graph; |
58 | } |
59 | |
60 | void set_schema_registry(const ISchemaRegistry* schema_registry) { |
61 | schema_registry_ = schema_registry; |
62 | } |
63 | |
64 | const ISchemaRegistry* get_schema_registry() const { |
65 | return schema_registry_; |
66 | } |
67 | |
68 | void set_model_dir(const std::string& model_dir) { |
69 | model_dir_ = model_dir; |
70 | } |
71 | |
72 | std::string get_model_dir() const { |
73 | return model_dir_; |
74 | } |
75 | |
76 | explicit CheckerContext() : ir_version_(-1) {} |
77 | |
78 | private: |
79 | int ir_version_; |
80 | std::unordered_map<std::string, int> opset_imports_; |
81 | bool is_main_graph_ = true; |
82 | const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance(); |
83 | std::string model_dir_; |
84 | }; |
85 | |
// Set of value names defined in one graph's lexical scope, optionally chained
// to the scope of an enclosing (parent) graph for name lookup.
class LexicalScopeContext final {
 public:
  LexicalScopeContext() = default;

  // Construct an instance with the lexical scope from the parent graph to allow
  // lookup of names from that scope via this_or_ancestor_graph_has.
  // The caller must ensure parent_context remains valid for the entire lifetime
  // of the new instance. Alternatively, if that cannot be guaranteed, create an
  // instance with the default constructor and populate output_names with the
  // values from the parent scope so the values are copied instead.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Re-parents this instance onto `parent_context` (same lifetime requirement
  // as the constructor above). This does NOT copy output_names: names already
  // added to this instance are kept. Self-assignment is a no-op; without the
  // guard the context would become its own parent and a failed lookup in
  // this_or_ancestor_graph_has would never terminate.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (this != &parent_context) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records `name` as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if `name` was added to this scope; ancestor scopes are not consulted.
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if `name` is defined in this scope or any ancestor scope. Walks the
  // parent chain iteratively (nearest scope first) so deep nesting cannot
  // exhaust the call stack.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    for (const LexicalScopeContext* scope = this; scope != nullptr; scope = scope->parent_context_) {
      if (scope->this_graph_has(name)) {
        return true;
      }
    }
    return false;
  }

  // public for backwards compatibility. please prefer the public interface of
  // this class over directly changing output_names
  std::unordered_set<std::string> output_names;

 private:
  // Non-owning pointer to the enclosing scope; nullptr for a root scope.
  const LexicalScopeContext* parent_context_{nullptr};
};
121 | |
// Alias for the type of the Version::IR_VERSION enumerator.
using IR_VERSION_TYPE = decltype(Version::IR_VERSION);

// Checkers for individual proto kinds. Each takes the proto to validate and a
// CheckerContext; the attribute/node/graph/function checkers additionally take
// a LexicalScopeContext (presumably the names visible from enclosing scopes).
// They return void, so failures are presumably reported by throwing
// ValidationError (e.g. via fail_check).
// NOTE(review): confirm against the out-of-line definitions.
void check_value_info(const ValueInfoProto& value_info, const CheckerContext&);
void check_tensor(const TensorProto& tensor, const CheckerContext&);
void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext&);
void check_sequence(const SequenceProto& sequence, const CheckerContext&);
void check_map(const MapProto& map, const CheckerContext&);
void check_optional(const OptionalProto& opt, const CheckerContext&);
void check_attribute(const AttributeProto& attr, const CheckerContext&, const LexicalScopeContext&);
void check_node(const NodeProto& node, const CheckerContext&, const LexicalScopeContext&);
void check_graph(const GraphProto& graph, const CheckerContext&, const LexicalScopeContext&);
void check_function(const FunctionProto& function, const CheckerContext&, const LexicalScopeContext&);
133 | |
// Checks schema compatibility between two opset versions for a given node:
// verifies that the node's schema is the same in both versions, which holds
// when the op schema did not change between them.
// NOTE(review): the two versions are presumably looked up for the node's
// domain in func_opset_imports and model_opset_imports respectively -- confirm
// in the definition.
void check_opset_compatibility(
    const NodeProto& node,
    const CheckerContext& ctx,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const std::unordered_map<std::string, int>& model_opset_imports);
142 | |
// Checks all model-local functions present in the ModelProto, using `ctx` and
// the lexical scope `parent_lex`.
// NOTE(review): presumably iterates model.functions() -- confirm in the
// definition.
void check_model_local_functions(
    const ModelProto& model,
    const CheckerContext& ctx,
    const LexicalScopeContext& parent_lex);
148 | |
// Checks a whole model, given either as an in-memory ModelProto or as a path
// (presumably to a serialized model file).
// NOTE(review): what `full_check` enables beyond the default checks is not
// visible here -- confirm in the definition.
void check_model(const ModelProto& model, bool full_check = false);
void check_model(const std::string& model_path, bool full_check = false);

// Returns whether `node` is treated as an experimental op.
// NOTE(review): the criterion is defined out of line -- confirm.
bool check_is_experimental_op(const NodeProto& node);
153 | |
154 | } // namespace checker |
155 | } // namespace ONNX_NAMESPACE |
156 | |