1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/version_converter/convert.h"
6
7namespace ONNX_NAMESPACE {
8namespace version_conversion {
9
10ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) {
11 // Get initial_opsetid from mp_in
12 OpSetID initial_struct(0);
13 for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) {
14 if (it->domain() == "" || it->domain() == "ai.onnx") {
15 initial_struct.setVersion(it->version());
16 break;
17 }
18 }
19 OpSetID target_struct = OpSetID(target_version);
20 DefaultVersionConverter v;
21 return v.convert_version(mp_in, initial_struct, target_struct);
22}
23
24void DefaultVersionConverter::convert_graph(
25 std::shared_ptr<Graph> g,
26 const OpSetID& initial_version,
27 const OpSetID& target_version) const {
28 assertNonNull(g);
29
30 // TODO: Move to Inter-Domain Converter
31 // Get initial model versions
32 // std::vector<OpSetID> initial_versions = g->opset_versions_mutable();
33
34 // No conversion necessary if Model has single, equivalent opset version
35 // if (initial_versions.size() == 1 && initial_versions[0].version ==
36 // target_version.version && initial_versions[0].domain ==
37 // target_version.domain) {
38 // return mp_in;
39 // }
40
41 // Check if versions are valid
42 assertInVersionRange(initial_version.version());
43 assertInVersionRange(target_version.version());
44
45 // Iterate over all versions to target_version for specified
46 int64_t curr_version = initial_version.version();
47 int64_t step;
48 if (target_version.version() > initial_version.version()) {
49 step = 1;
50 } else {
51 step = -1;
52 }
53 // Identify index of this domain in g.opset_versions
54 unsigned int domain_index = 0;
55 for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) {
56 if (g->opset_versions_mutable()[i].domain() == "") {
57 domain_index = i;
58 }
59 }
60 while (curr_version != target_version.version()) {
61 debug(
62 "curr_version: " + ONNX_NAMESPACE::to_string(curr_version) +
63 ", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step));
64 Node* cur_op;
65 graph_node_list_iterator it = g->begin();
66 // Iterate through and call adapter returned by adapter_lookup for ops from
67 // current_version opset. We have to manipulate the iterator explicitly because cur_op
68 // might change when applying the adapter (e.g. for deprecated ops)
69 while (it != g->end()) {
70 cur_op = *it;
71 debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString()));
72 const std::string op_name = cur_op->kind().toString();
73 if (op_name == "ConstantFill") {
74 std::cerr
75 << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. "
76 "Please be advised the converted model may not be working properly if target runtime does not support this "
77 "experimental op."
78 << std::endl;
79 } else if (cur_op->domain() != "" && cur_op->domain() != "ai.onnx") {
80 std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << std::endl;
81 } else if (op_name != "Undefined" && op_name != "Captured") {
82 auto& op_domain_map = all_schemas.at(op_name);
83 OpSetID curr_id(curr_version);
84 OpSetID next_id(curr_version + step);
85 if (searchOpDomainMap(op_domain_map, curr_version, step)) {
86 // Op is specifically defined for this domain and version
87 auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id);
88 // If adapter_lookup returns null, no adapter is present.
89 // Error thrown by adapter_lookup
90 if (DEBUG)
91 std::cerr << "Applying adapter" << std::endl;
92 // adapt should handle replacing node in graph
93 cur_op = op_adapter.adapt(g, cur_op);
94 it = graph_node_list_iterator(cur_op, kNextDirection);
95 }
96 // Recursively convert any subgraph attributes
97 for (const auto& attr : cur_op->attributeNames()) {
98 if (cur_op->kindOf(attr) == AttributeKind::g) {
99 convert_graph(cur_op->g(attr), curr_id, next_id);
100 }
101 }
102 }
103 it++;
104 }
105 // Update model version
106 curr_version += step;
107 g->opset_versions_mutable()[domain_index].incrementVersion(step);
108 }
109}
110
111ModelProto DefaultVersionConverter::convert_version(
112 const ModelProto& mp_in,
113 const OpSetID& initial_version,
114 const OpSetID& target_version) const {
115 const std::string& initial_domain = initial_version.domain();
116 const std::string& target_domain = target_version.domain();
117 assertDefaultDomain(initial_domain, target_domain);
118
119 for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) {
120 if (it->domain() == initial_version.domain()) {
121 ONNX_ASSERTM(
122 initial_version.version() == it->version(), "initial_version does not reflect current state of model");
123 }
124 }
125
126 std::shared_ptr<Graph> g(ImportModelProto(mp_in));
127
128 convert_graph(g, initial_version, target_version);
129
130 // Export g as ModelProto
131 debug("Finished conversion; returning model");
132 ModelProto mp_out = PrepareOutput(mp_in);
133 ExportModelProto(&mp_out, g);
134 return mp_out;
135}
136
137} // namespace version_conversion
138} // namespace ONNX_NAMESPACE
139