1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#pragma once
6
7#include <stdexcept>
8#include <unordered_map>
9#include <unordered_set>
10#include "onnx/defs/function.h"
11#include "onnx/defs/schema.h"
12#include "onnx/onnx-data_pb.h"
13#include "onnx/onnx-operators_pb.h"
14#include "onnx/onnx_pb.h"
15#include "onnx/string_utils.h"
16
17namespace ONNX_NAMESPACE {
18namespace checker {
// Exception type raised by the checker when a proto fails validation.
//
// Callers can attach context with AppendContext() while the exception
// propagates outward (e.g. "which node", then "which enclosing node"). what()
// reports the original message followed by every appended context, in the
// order in which they were appended.
class ValidationError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  // Returns the original message, extended with any appended context.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Adds a "==> Context: <context>" section to the message reported by what().
  // Each call appends a new section and earlier sections are kept, so context
  // added at an inner level survives when an outer level adds its own.
  void AppendContext(const std::string& context) {
    if (expanded_message_.empty()) {
      // First context: start from the message the exception was built with.
      expanded_message_ = std::runtime_error::what();
    }
    expanded_message_ += "\n\n==> Context: ";
    expanded_message_ += context;
  }

 private:
  // Base message plus all appended contexts; empty until the first
  // AppendContext() call, in which case what() falls back to the base message.
  std::string expanded_message_;
};
35
// Raises ONNX_NAMESPACE::checker::ValidationError whose message is built from
// ONNX_NAMESPACE::MakeString(__VA_ARGS__). MakeString presumably concatenates
// its arguments, allowing e.g.
//   fail_check("Field '", name, "' is required but missing.");
// The raise goes through ONNX_THROW_EX, which is defined outside this header.
// NOTE(review): confirm how ONNX_THROW_EX behaves in builds without exceptions.
// NOTE(review): the expansion already ends in ';', so `fail_check(...);`
// produces an extra empty statement. Do not use it as the unbraced body of an
// `if` that is followed by `else`; brace such bodies instead.
#define fail_check(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::checker::ValidationError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
38
39class CheckerContext final {
40 public:
41 int get_ir_version() const {
42 return ir_version_;
43 }
44 void set_ir_version(int v) {
45 ir_version_ = v;
46 }
47 const std::unordered_map<std::string, int>& get_opset_imports() const {
48 return opset_imports_;
49 }
50 void set_opset_imports(std::unordered_map<std::string, int> imps) {
51 opset_imports_ = std::move(imps);
52 }
53 bool is_main_graph() const {
54 return is_main_graph_;
55 }
56 void set_is_main_graph(bool is_main_graph) {
57 is_main_graph_ = is_main_graph;
58 }
59
60 void set_schema_registry(const ISchemaRegistry* schema_registry) {
61 schema_registry_ = schema_registry;
62 }
63
64 const ISchemaRegistry* get_schema_registry() const {
65 return schema_registry_;
66 }
67
68 void set_model_dir(const std::string& model_dir) {
69 model_dir_ = model_dir;
70 }
71
72 std::string get_model_dir() const {
73 return model_dir_;
74 }
75
76 explicit CheckerContext() : ir_version_(-1) {}
77
78 private:
79 int ir_version_;
80 std::unordered_map<std::string, int> opset_imports_;
81 bool is_main_graph_ = true;
82 const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
83 std::string model_dir_;
84};
85
// Records the value names defined in one graph's lexical scope. A scope may
// optionally be chained to the scope of an enclosing (parent) graph, so that
// names from ancestors are also visible.
//
// NOTE: copy construction and copy assignment do NOT copy output_names. Both
// make the target refer to the source as its parent scope.
// NOTE(review): no move operations are declared, so rvalue sources (std::move,
// a non-elided return of a named local) also use these copy operations. The
// result points at an object that may be about to be destroyed.
class LexicalScopeContext final {
 public:
  LexicalScopeContext() = default;

  // Construct an instance with the lexical scope from the parent graph to allow
  // lookup of names from that scope via this_or_ancestor_graph_has.
  // The caller must ensure parent_context remains valid for the entire lifetime
  // of the new instance. Alternatively, if that cannot be guaranteed, create an
  // instance with the default constructor and populate output_names with the
  // values from the parent scope so the values are copied instead.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Re-parents this scope onto parent_context; this scope's own output_names
  // are left unchanged. The same lifetime requirement as the constructor
  // applies. Self-assignment is a no-op: making a scope its own parent would
  // form a cycle, and lookups of names that are not present would never
  // terminate.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (this != &parent_context) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records `name` as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if `name` was added to this scope itself; ancestors are not consulted.
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if `name` is defined in this scope or in any ancestor scope. The walk
  // is iterative, so deeply nested scopes do not grow the call stack.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    for (const LexicalScopeContext* scope = this; scope != nullptr; scope = scope->parent_context_) {
      if (scope->this_graph_has(name)) {
        return true;
      }
    }
    return false;
  }

  // public for backwards compatibility. please prefer the public interface of
  // this class over directly changing output_names
  std::unordered_set<std::string> output_names;

 private:
  // Enclosing scope, or nullptr for a root scope. Not owned.
  const LexicalScopeContext* parent_context_{nullptr};
};
121
// Type of the Version::IR_VERSION enumerator. Assuming Version is an
// enumeration, this is the Version enum type itself, not an integer type.
using IR_VERSION_TYPE = decltype(Version::IR_VERSION);

// Validators for individual protos. Each takes the CheckerContext describing
// the model being checked. The attribute/node/graph/function validators also
// take the LexicalScopeContext that supplies the names visible at that point.
// NOTE(review): failures are presumably reported by throwing ValidationError
// (see fail_check); confirm in the implementation file.
void check_value_info(const ValueInfoProto& value_info, const CheckerContext&);
void check_tensor(const TensorProto& tensor, const CheckerContext&);
void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext&);
void check_sequence(const SequenceProto& sequence, const CheckerContext&);
void check_map(const MapProto& map, const CheckerContext&);
void check_optional(const OptionalProto& opt, const CheckerContext&);
void check_attribute(const AttributeProto& attr, const CheckerContext&, const LexicalScopeContext&);
void check_node(const NodeProto& node, const CheckerContext&, const LexicalScopeContext&);
void check_graph(const GraphProto& graph, const CheckerContext&, const LexicalScopeContext&);
void check_function(const FunctionProto& function, const CheckerContext&, const LexicalScopeContext&);

// Checks, for the given node, that its operator schema is the same under two
// sets of opset imports: the function's (func_opset_imports) and the model's
// (model_opset_imports). This holds when the op schema did not change between
// the two opset versions.
void check_opset_compatibility(
    const NodeProto& node,
    const CheckerContext& ctx,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const std::unordered_map<std::string, int>& model_opset_imports);

// Checks all model-local functions present in the ModelProto. parent_lex
// supplies the enclosing lexical scope for name lookup.
void check_model_local_functions(
    const ModelProto& model,
    const CheckerContext& ctx,
    const LexicalScopeContext& parent_lex);

// Checks a whole model, given either as an in-memory ModelProto or as a path
// (presumably to a serialized model file).
// NOTE(review): full_check (default false) presumably enables checks beyond the
// structural ones; confirm in the implementation file.
void check_model(const ModelProto& model, bool full_check = false);
void check_model(const std::string& model_path, bool full_check = false);

// NOTE(review): presumably reports whether the node's operator is marked
// experimental; confirm in the implementation file.
bool check_is_experimental_op(const NodeProto& node);
153
154} // namespace checker
155} // namespace ONNX_NAMESPACE
156