1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "onnx/version_converter/convert.h" |
6 | |
7 | namespace ONNX_NAMESPACE { |
8 | namespace version_conversion { |
9 | |
10 | ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) { |
11 | // Get initial_opsetid from mp_in |
12 | OpSetID initial_struct(0); |
13 | for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) { |
14 | if (it->domain() == "" || it->domain() == "ai.onnx" ) { |
15 | initial_struct.setVersion(it->version()); |
16 | break; |
17 | } |
18 | } |
19 | OpSetID target_struct = OpSetID(target_version); |
20 | DefaultVersionConverter v; |
21 | return v.convert_version(mp_in, initial_struct, target_struct); |
22 | } |
23 | |
24 | void DefaultVersionConverter::convert_graph( |
25 | std::shared_ptr<Graph> g, |
26 | const OpSetID& initial_version, |
27 | const OpSetID& target_version) const { |
28 | assertNonNull(g); |
29 | |
30 | // TODO: Move to Inter-Domain Converter |
31 | // Get initial model versions |
32 | // std::vector<OpSetID> initial_versions = g->opset_versions_mutable(); |
33 | |
34 | // No conversion necessary if Model has single, equivalent opset version |
35 | // if (initial_versions.size() == 1 && initial_versions[0].version == |
36 | // target_version.version && initial_versions[0].domain == |
37 | // target_version.domain) { |
38 | // return mp_in; |
39 | // } |
40 | |
41 | // Check if versions are valid |
42 | assertInVersionRange(initial_version.version()); |
43 | assertInVersionRange(target_version.version()); |
44 | |
45 | // Iterate over all versions to target_version for specified |
46 | int64_t curr_version = initial_version.version(); |
47 | int64_t step; |
48 | if (target_version.version() > initial_version.version()) { |
49 | step = 1; |
50 | } else { |
51 | step = -1; |
52 | } |
53 | // Identify index of this domain in g.opset_versions |
54 | unsigned int domain_index = 0; |
55 | for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) { |
56 | if (g->opset_versions_mutable()[i].domain() == "" ) { |
57 | domain_index = i; |
58 | } |
59 | } |
60 | while (curr_version != target_version.version()) { |
61 | debug( |
62 | "curr_version: " + ONNX_NAMESPACE::to_string(curr_version) + |
63 | ", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step)); |
64 | Node* cur_op; |
65 | graph_node_list_iterator it = g->begin(); |
66 | // Iterate through and call adapter returned by adapter_lookup for ops from |
67 | // current_version opset. We have to manipulate the iterator explicitly because cur_op |
68 | // might change when applying the adapter (e.g. for deprecated ops) |
69 | while (it != g->end()) { |
70 | cur_op = *it; |
71 | debug(std::string("Finding schema for " ) + std::string(cur_op->kind().toString())); |
72 | const std::string op_name = cur_op->kind().toString(); |
73 | if (op_name == "ConstantFill" ) { |
74 | std::cerr |
75 | << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. " |
76 | "Please be advised the converted model may not be working properly if target runtime does not support this " |
77 | "experimental op." |
78 | << std::endl; |
79 | } else if (cur_op->domain() != "" && cur_op->domain() != "ai.onnx" ) { |
80 | std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << std::endl; |
81 | } else if (op_name != "Undefined" && op_name != "Captured" ) { |
82 | auto& op_domain_map = all_schemas.at(op_name); |
83 | OpSetID curr_id(curr_version); |
84 | OpSetID next_id(curr_version + step); |
85 | if (searchOpDomainMap(op_domain_map, curr_version, step)) { |
86 | // Op is specifically defined for this domain and version |
87 | auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id); |
88 | // If adapter_lookup returns null, no adapter is present. |
89 | // Error thrown by adapter_lookup |
90 | if (DEBUG) |
91 | std::cerr << "Applying adapter" << std::endl; |
92 | // adapt should handle replacing node in graph |
93 | cur_op = op_adapter.adapt(g, cur_op); |
94 | it = graph_node_list_iterator(cur_op, kNextDirection); |
95 | } |
96 | // Recursively convert any subgraph attributes |
97 | for (const auto& attr : cur_op->attributeNames()) { |
98 | if (cur_op->kindOf(attr) == AttributeKind::g) { |
99 | convert_graph(cur_op->g(attr), curr_id, next_id); |
100 | } |
101 | } |
102 | } |
103 | it++; |
104 | } |
105 | // Update model version |
106 | curr_version += step; |
107 | g->opset_versions_mutable()[domain_index].incrementVersion(step); |
108 | } |
109 | } |
110 | |
111 | ModelProto DefaultVersionConverter::convert_version( |
112 | const ModelProto& mp_in, |
113 | const OpSetID& initial_version, |
114 | const OpSetID& target_version) const { |
115 | const std::string& initial_domain = initial_version.domain(); |
116 | const std::string& target_domain = target_version.domain(); |
117 | assertDefaultDomain(initial_domain, target_domain); |
118 | |
119 | for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) { |
120 | if (it->domain() == initial_version.domain()) { |
121 | ONNX_ASSERTM( |
122 | initial_version.version() == it->version(), "initial_version does not reflect current state of model" ); |
123 | } |
124 | } |
125 | |
126 | std::shared_ptr<Graph> g(ImportModelProto(mp_in)); |
127 | |
128 | convert_graph(g, initial_version, target_version); |
129 | |
130 | // Export g as ModelProto |
131 | debug("Finished conversion; returning model" ); |
132 | ModelProto mp_out = PrepareOutput(mp_in); |
133 | ExportModelProto(&mp_out, g); |
134 | return mp_out; |
135 | } |
136 | |
137 | } // namespace version_conversion |
138 | } // namespace ONNX_NAMESPACE |
139 | |