1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// Default converter for ONNX models between different opset versions
6// in the default domain ("" or "ai.onnx").
7
8#pragma once
9
10#include "onnx/version_converter/BaseConverter.h"
11#include "onnx/version_converter/adapters/axes_attribute_to_input.h"
12#include "onnx/version_converter/adapters/axes_input_to_attribute.h"
13#include "onnx/version_converter/adapters/batch_normalization_13_14.h"
14#include "onnx/version_converter/adapters/broadcast_backward_compatibility.h"
15#include "onnx/version_converter/adapters/broadcast_forward_compatibility.h"
16#include "onnx/version_converter/adapters/cast_9_8.h"
17#include "onnx/version_converter/adapters/clip_10_11.h"
18#include "onnx/version_converter/adapters/compatible.h"
19#include "onnx/version_converter/adapters/dropout_11_12.h"
20#include "onnx/version_converter/adapters/extend_supported_types.h"
21#include "onnx/version_converter/adapters/gemm_6_7.h"
22#include "onnx/version_converter/adapters/gemm_7_6.h"
23#include "onnx/version_converter/adapters/maxpool_8_7.h"
24#include "onnx/version_converter/adapters/no_previous_version.h"
25#include "onnx/version_converter/adapters/pad_10_11.h"
26#include "onnx/version_converter/adapters/reshape_4_5.h"
27#include "onnx/version_converter/adapters/reshape_5_4.h"
28#include "onnx/version_converter/adapters/resize_10_11.h"
29#include "onnx/version_converter/adapters/scan_8_9.h"
30#include "onnx/version_converter/adapters/scan_9_8.h"
31#include "onnx/version_converter/adapters/scatter_10_11.h"
32#include "onnx/version_converter/adapters/slice_9_10.h"
33#include "onnx/version_converter/adapters/softmax_12_13.h"
34#include "onnx/version_converter/adapters/split_12_13.h"
35#include "onnx/version_converter/adapters/split_13_12.h"
36#include "onnx/version_converter/adapters/split_17_18.h"
37#include "onnx/version_converter/adapters/sum_8_7.h"
38#include "onnx/version_converter/adapters/topk_9_10.h"
39#include "onnx/version_converter/adapters/type_restriction.h"
40#include "onnx/version_converter/adapters/upsample_6_7.h"
41#include "onnx/version_converter/adapters/upsample_8_9.h"
42#include "onnx/version_converter/adapters/upsample_9_10.h"
43#include "onnx/version_converter/adapters/upsample_9_8.h"
44
45#include "onnx/version_converter/adapters/transformers.h"
46
47namespace ONNX_NAMESPACE {
48namespace version_conversion {
49
50class DefaultVersionConverter : public BaseVersionConverter {
51 private:
52 bool DEBUG = false;
53
54 std::pair<int, int> version_range;
55
56 bool searchOpDomainMap(
57 const std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>& op_domain_map,
58 int64_t curr_version,
59 int64_t step) const {
60 bool up = step == 1;
61 const auto version_it = op_domain_map.find("");
62 return version_it != op_domain_map.end() &&
63 ((version_it->second.find(curr_version) != version_it->second.end() && !up) ||
64 (version_it->second.find(curr_version + step) != version_it->second.end() && up));
65 }
66
67 void debug(const std::string& str) const {
68 if (DEBUG)
69 std::cerr << str << std::endl;
70 }
71
72 void assertInVersionRange(int64_t version) const {
73 ONNX_ASSERTM(
74 version >= version_range.first && version <= version_range.second,
75 "Warning: invalid version (must be between %d and %d)",
76 version_range.first,
77 version_range.second);
78 }
79
80 void assertDefaultDomain(const std::string& initial_domain, const std::string& target_domain) const {
81 ONNX_ASSERTM(
82 (initial_domain == "" || initial_domain == "ai.onnx") && (target_domain == "" || target_domain == "ai.onnx"),
83 "Warning: default onnx version converter can only convert "
84 " between default domain opset versions ('' or 'ai.onnx')\n");
85 ONNX_ASSERTM(initial_domain == target_domain, "initial_version and target_version must have the same domains");
86 }
87
88 void convert_graph(std::shared_ptr<Graph> g, const OpSetID& initial_version, const OpSetID& target_version) const;
89
90 public:
91 DefaultVersionConverter() {
92 const std::unordered_map<std::string, std::pair<int, int>>& versions_map =
93 OpSchemaRegistry::DomainToVersionRange::Instance().Map();
94 version_range = versions_map.at("");
95 // Register adapters to the version converter
96 const std::vector<OpSchema> all_opschemas = OpSchemaRegistry::get_all_schemas_with_history();
97
98 for (const OpSchema& schema : all_opschemas) {
99 all_schemas[schema.Name()][schema.domain()][(int64_t)schema.since_version()] = &schema;
100 }
101
102 // Iterate through all_schemas to determine NoPreviousVersionAdapters
103 for (auto& op_pair : all_schemas) {
104 const auto default_versions = op_pair.second.find("");
105 if (default_versions != op_pair.second.end()) {
106 int64_t min_version = version_range.second;
107 for (auto& version_pair : default_versions->second) {
108 if (version_pair.first < min_version) {
109 min_version = version_pair.first;
110 }
111 }
112 if (min_version > 1) {
113 registerAdapter(
114 make_unique<NoPreviousVersionAdapter>(op_pair.first, OpSetID(min_version), OpSetID(min_version - 1)));
115 }
116 }
117 }
118
119 /******** 1 -> 2 ********/
120 // Missing in this group: GlobalLpPool, LpPool, Pad, Split
121
122 /******** 2 -> 3 ********/
123 // Missing in this group: GRU
124
125 /******** 3 -> 4 ********/
126 registerAdapter("Concat", 3, 4, SetAttributeIfAbsent(kaxis, 1));
127
128 /******** 4 -> 3 ********/
129 std::vector<TensorProto_DataType> concat_unallowed_types = {
130 TensorProto_DataType_INT32,
131 TensorProto_DataType_INT64,
132 TensorProto_DataType_UINT32,
133 TensorProto_DataType_UINT64,
134 TensorProto_DataType_UINT8,
135 TensorProto_DataType_UINT16,
136 TensorProto_DataType_INT8,
137 TensorProto_DataType_INT16,
138 TensorProto_DataType_STRING,
139 TensorProto_DataType_BOOL};
140 registerAdapter(make_unique<TypeRestriction>("Concat", OpSetID(4), OpSetID(3), concat_unallowed_types));
141
142 /******** 4 -> 5 ********/
143 registerAdapter(make_unique<Reshape_4_5>());
144
145 /******** 5 -> 4 ********/
146 registerAdapter(make_unique<Reshape_5_4>());
147
148 /******** 5 -> 6 ********/
149 // Missing in this group: Cast, Tile
150 auto removeConsumedInputs = RemoveAttribute(kconsumed_inputs);
151 registerAdapter("Add", 5, 6, removeConsumedInputs);
152 registerAdapter("Mul", 5, 6, removeConsumedInputs);
153 registerAdapter(make_unique<CompatibleAdapter>("Gemm", OpSetID(5), OpSetID(6)));
154 registerAdapter("Relu", 5, 6, removeConsumedInputs);
155 registerAdapter("BatchNormalization", 5, 6, removeConsumedInputs);
156 registerAdapter("Sum", 5, 6, removeConsumedInputs);
157 registerAdapter("Dropout", 5, 6, removeConsumedInputs);
158 registerAdapter("Abs", 5, 6, removeConsumedInputs);
159 registerAdapter("Ceil", 5, 6, removeConsumedInputs);
160 registerAdapter("Clip", 5, 6, removeConsumedInputs);
161 registerAdapter("Div", 5, 6, removeConsumedInputs);
162 registerAdapter("Elu", 5, 6, removeConsumedInputs);
163 registerAdapter("Exp", 5, 6, removeConsumedInputs);
164 registerAdapter("Floor", 5, 6, removeConsumedInputs);
165 registerAdapter("HardSigmoid", 5, 6, removeConsumedInputs);
166 registerAdapter("InstanceNormalization", 5, 6, removeConsumedInputs);
167 registerAdapter("LeakyRelu", 5, 6, removeConsumedInputs);
168 registerAdapter("Log", 5, 6, removeConsumedInputs);
169 registerAdapter("Max", 5, 6, removeConsumedInputs);
170 registerAdapter("Mean", 5, 6, removeConsumedInputs);
171 registerAdapter("Min", 5, 6, removeConsumedInputs);
172 registerAdapter("Neg", 5, 6, removeConsumedInputs);
173 registerAdapter("PRelu", 5, 6, removeConsumedInputs);
174 registerAdapter("Reciprocal", 5, 6, removeConsumedInputs);
175 registerAdapter("Selu", 5, 6, removeConsumedInputs);
176 registerAdapter("Sigmoid", 5, 6, removeConsumedInputs);
177 registerAdapter("Sqrt", 5, 6, removeConsumedInputs);
178 registerAdapter("Sub", 5, 6, removeConsumedInputs);
179 registerAdapter("Tanh", 5, 6, removeConsumedInputs);
180
181 /******** 6 -> 5 ********/
182 std::vector<TensorProto_DataType> broadcast_unallowed_types = {
183 TensorProto_DataType_INT32,
184 TensorProto_DataType_INT64,
185 TensorProto_DataType_UINT32,
186 TensorProto_DataType_UINT64};
187 std::vector<TensorProto_DataType> int_unallowed_types = {
188 TensorProto_DataType_UINT8,
189 TensorProto_DataType_UINT16,
190 TensorProto_DataType_UINT32,
191 TensorProto_DataType_UINT64,
192 TensorProto_DataType_INT8,
193 TensorProto_DataType_INT16,
194 TensorProto_DataType_INT32,
195 TensorProto_DataType_INT64};
196 std::vector<TensorProto_DataType> neg_unallowed_types = {
197 TensorProto_DataType_INT32, TensorProto_DataType_INT8, TensorProto_DataType_UINT16, TensorProto_DataType_INT64};
198 registerAdapter(make_unique<TypeRestriction>("Add", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
199 registerAdapter(make_unique<TypeRestriction>("Mul", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
200 registerAdapter(make_unique<TypeRestriction>("Sub", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
201 registerAdapter(make_unique<TypeRestriction>("Div", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
202 registerAdapter(make_unique<TypeRestriction>("Abs", OpSetID(6), OpSetID(5), int_unallowed_types));
203 registerAdapter(make_unique<TypeRestriction>("Neg", OpSetID(6), OpSetID(5), neg_unallowed_types));
204 registerAdapter("BatchNormalization", 6, 5, SetAttribute(kconsumed_inputs, std::vector<int64_t>({0, 0})));
205 registerAdapter(make_unique<CompatibleAdapter>("Gemm", OpSetID(6), OpSetID(5)));
206 registerAdapter(make_unique<CompatibleAdapter>("Relu", OpSetID(6), OpSetID(5)));
207 registerAdapter(make_unique<CompatibleAdapter>("Sum", OpSetID(6), OpSetID(5)));
208 registerAdapter(make_unique<CompatibleAdapter>("Dropout", OpSetID(6), OpSetID(5)));
209
210 /******** 6 -> 7 ********/
211 // Missing in this group: And, Equal, Greater, GRU, Less, LSTM, Or, RNN, Upsample, Xor
212 registerAdapter(make_unique<BroadcastForwardCompatibility>("Add", OpSetID(6), OpSetID(7)));
213 registerAdapter(make_unique<CompatibleAdapter>("AveragePool", OpSetID(6), OpSetID(7)));
214 registerAdapter(make_unique<BroadcastForwardCompatibility>("Div", OpSetID(6), OpSetID(7)));
215 registerAdapter(make_unique<BroadcastForwardCompatibility>("Mul", OpSetID(6), OpSetID(7)));
216 registerAdapter(make_unique<BroadcastForwardCompatibility>("Pow", OpSetID(6), OpSetID(7)));
217 registerAdapter(make_unique<CompatibleAdapter>("PRelu", OpSetID(6), OpSetID(7)));
218 registerAdapter(make_unique<BroadcastForwardCompatibility>("Sub", OpSetID(6), OpSetID(7)));
219 registerAdapter(make_unique<Gemm_6_7>());
220 registerAdapter("BatchNormalization", 6, 7, RemoveAttributeNotEq(kis_test, 0));
221 registerAdapter("Dropout", 6, 7, RemoveAttributeNotEq(kis_test, 0));
222 registerAdapter(make_unique<Upsample_6_7>());
223
224 /******** 7 -> 6 ********/
225 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Add", OpSetID(7), OpSetID(6)));
226 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Div", OpSetID(7), OpSetID(6)));
227 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Mul", OpSetID(7), OpSetID(6)));
228 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Pow", OpSetID(7), OpSetID(6)));
229 registerAdapter(make_unique<CompatibleAdapter>("PRelu", OpSetID(7), OpSetID(6)));
230 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Sub", OpSetID(7), OpSetID(6)));
231 registerAdapter("BatchNormalization", 7, 6, SetAttribute(kis_test, 1));
232 registerAdapter("Dropout", 7, 6, SetAttribute(kis_test, 1));
233 registerAdapter(make_unique<Gemm_7_6>());
234 registerAdapter("AveragePool", 7, 6, RemoveAttribute(kcount_include_pad, 0));
235
236 /******** 7 -> 8 ********/
237 registerAdapter(make_unique<CompatibleAdapter>("Max", OpSetID(7), OpSetID(8)));
238 registerAdapter(make_unique<CompatibleAdapter>("Min", OpSetID(7), OpSetID(8)));
239 registerAdapter(make_unique<CompatibleAdapter>("Mean", OpSetID(7), OpSetID(8)));
240 registerAdapter(make_unique<CompatibleAdapter>("Sum", OpSetID(7), OpSetID(8)));
241 registerAdapter(make_unique<CompatibleAdapter>("MaxPool", OpSetID(7), OpSetID(8)));
242
243 /******** 8 -> 7 ********/
244 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Max", OpSetID(8), OpSetID(7)));
245 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Min", OpSetID(8), OpSetID(7)));
246 registerAdapter(make_unique<BroadcastBackwardCompatibility>("Mean", OpSetID(8), OpSetID(7)));
247 registerAdapter(make_unique<Sum_8_7>());
248 registerAdapter(make_unique<MaxPool_8_7>());
249
250 /******** 8 -> 9 ********/
251 registerAdapter(make_unique<CompatibleAdapter>("Flatten", OpSetID(8), OpSetID(9)));
252 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(8), OpSetID(9)));
253 registerAdapter(make_unique<CompatibleAdapter>("MatMul", OpSetID(8), OpSetID(9)));
254 registerAdapter(make_unique<CompatibleAdapter>("Gemm", OpSetID(8), OpSetID(9)));
255 registerAdapter(make_unique<CompatibleAdapter>("PRelu", OpSetID(8), OpSetID(9)));
256 registerAdapter(make_unique<CompatibleAdapter>("Greater", OpSetID(8), OpSetID(9)));
257 registerAdapter(make_unique<CompatibleAdapter>("Less", OpSetID(8), OpSetID(9)));
258 registerAdapter(make_unique<CompatibleAdapter>("Cast", OpSetID(8), OpSetID(9)));
259 registerAdapter("BatchNormalization", 8, 9, RemoveAttribute(kspatial, 1));
260 registerAdapter(make_unique<Scan_8_9>());
261 registerAdapter(make_unique<Upsample_8_9>());
262
263 /******** 9 -> 8 ********/
264 registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(9), OpSetID(8)));
265 registerAdapter(make_unique<ExtendSupportedTypes>("Flatten", OpSetID(9), OpSetID(8)));
266 registerAdapter(make_unique<ExtendSupportedTypes>("Constant", OpSetID(9), OpSetID(8)));
267 registerAdapter(make_unique<ExtendSupportedTypes>("MatMul", OpSetID(9), OpSetID(8)));
268 registerAdapter(make_unique<ExtendSupportedTypes>("Gemm", OpSetID(9), OpSetID(8)));
269 registerAdapter(make_unique<ExtendSupportedTypes>("PRelu", OpSetID(9), OpSetID(8)));
270 registerAdapter(make_unique<ExtendSupportedTypes>("Greater", OpSetID(9), OpSetID(8)));
271 registerAdapter(make_unique<ExtendSupportedTypes>("Less", OpSetID(9), OpSetID(8)));
272 registerAdapter(make_unique<Cast_9_8>());
273 registerAdapter(make_unique<Scan_9_8>());
274 registerAdapter(make_unique<Upsample_9_8>());
275
276 /******** 9 -> 10 ********/
277 registerAdapter(make_unique<CompatibleAdapter>("AveragePool", OpSetID(9), OpSetID(10)));
278 registerAdapter(make_unique<CompatibleAdapter>("MaxPool", OpSetID(9), OpSetID(10)));
279 registerAdapter(make_unique<CompatibleAdapter>("Dropout", OpSetID(9), OpSetID(10)));
280 registerAdapter(make_unique<Slice_9_10>());
281 registerAdapter(make_unique<TopK_9_10>());
282 registerAdapter(make_unique<Upsample_9_10>());
283
284 /******** 10 -> 9 ********/
285 registerAdapter(make_unique<CompatibleAdapter>("Dropout", OpSetID(10), OpSetID(9)));
286
287 /******** 10 -> 11 ********/
288 registerAdapter(make_unique<CompatibleAdapter>("ArgMax", OpSetID(10), OpSetID(11)));
289 registerAdapter(make_unique<CompatibleAdapter>("ArgMin", OpSetID(10), OpSetID(11)));
290 registerAdapter(make_unique<CompatibleAdapter>("AveragePool", OpSetID(10), OpSetID(11)));
291 registerAdapter(make_unique<CompatibleAdapter>("Concat", OpSetID(10), OpSetID(11)));
292 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(10), OpSetID(11)));
293 registerAdapter(make_unique<CompatibleAdapter>("Compress", OpSetID(10), OpSetID(11)));
294 registerAdapter(make_unique<CompatibleAdapter>("Conv", OpSetID(10), OpSetID(11)));
295 registerAdapter(make_unique<CompatibleAdapter>("ConvTranspose", OpSetID(10), OpSetID(11)));
296 registerAdapter(make_unique<CompatibleAdapter>("DepthToSpace", OpSetID(10), OpSetID(11)));
297 registerAdapter(make_unique<CompatibleAdapter>("Equal", OpSetID(10), OpSetID(11)));
298 registerAdapter(make_unique<CompatibleAdapter>("Flatten", OpSetID(10), OpSetID(11)));
299 registerAdapter(make_unique<CompatibleAdapter>("Gather", OpSetID(10), OpSetID(11)));
300 registerAdapter(make_unique<CompatibleAdapter>("Gemm", OpSetID(10), OpSetID(11)));
301 registerAdapter(make_unique<CompatibleAdapter>("Hardmax", OpSetID(10), OpSetID(11)));
302 registerAdapter(make_unique<CompatibleAdapter>("If", OpSetID(10), OpSetID(11)));
303 registerAdapter(make_unique<CompatibleAdapter>("LogSoftmax", OpSetID(10), OpSetID(11)));
304 registerAdapter(make_unique<CompatibleAdapter>("Loop", OpSetID(10), OpSetID(11)));
305 registerAdapter(make_unique<CompatibleAdapter>("LpPool", OpSetID(10), OpSetID(11)));
306 registerAdapter(make_unique<CompatibleAdapter>("MaxPool", OpSetID(10), OpSetID(11)));
307 registerAdapter(make_unique<CompatibleAdapter>("MaxUnpool", OpSetID(10), OpSetID(11)));
308 registerAdapter(make_unique<CompatibleAdapter>("NonMaxSuppression", OpSetID(10), OpSetID(11)));
309 registerAdapter(make_unique<CompatibleAdapter>("OneHot", OpSetID(10), OpSetID(11)));
310 registerAdapter(make_unique<CompatibleAdapter>("ReduceL1", OpSetID(10), OpSetID(11)));
311 registerAdapter(make_unique<CompatibleAdapter>("ReduceL2", OpSetID(10), OpSetID(11)));
312 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(10), OpSetID(11)));
313 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(10), OpSetID(11)));
314 registerAdapter(make_unique<CompatibleAdapter>("ReduceMax", OpSetID(10), OpSetID(11)));
315 registerAdapter(make_unique<CompatibleAdapter>("ReduceMean", OpSetID(10), OpSetID(11)));
316 registerAdapter(make_unique<CompatibleAdapter>("ReduceMin", OpSetID(10), OpSetID(11)));
317 registerAdapter(make_unique<CompatibleAdapter>("ReduceProd", OpSetID(10), OpSetID(11)));
318 registerAdapter(make_unique<CompatibleAdapter>("ReduceSum", OpSetID(10), OpSetID(11)));
319 registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(10), OpSetID(11)));
320 registerAdapter(make_unique<CompatibleAdapter>("Scan", OpSetID(10), OpSetID(11)));
321 registerAdapter(make_unique<CompatibleAdapter>("Softmax", OpSetID(10), OpSetID(11)));
322 registerAdapter(make_unique<CompatibleAdapter>("Slice", OpSetID(10), OpSetID(11)));
323 registerAdapter(make_unique<CompatibleAdapter>("Split", OpSetID(10), OpSetID(11)));
324 registerAdapter(make_unique<CompatibleAdapter>("Squeeze", OpSetID(10), OpSetID(11)));
325 registerAdapter(make_unique<CompatibleAdapter>("TopK", OpSetID(10), OpSetID(11)));
326 registerAdapter(make_unique<CompatibleAdapter>("Unsqueeze", OpSetID(10), OpSetID(11)));
327 registerAdapter(make_unique<Clip_10_11>());
328 registerAdapter(make_unique<Pad_10_11>());
329 registerAdapter(make_unique<Resize_10_11>());
330 registerAdapter(make_unique<Scatter_10_11>());
331
332 /******** 11 -> 10 ********/
333 std::vector<TensorProto_DataType> equal_unallowed_types = {
334 TensorProto_DataType_UINT8,
335 TensorProto_DataType_UINT16,
336 TensorProto_DataType_UINT32,
337 TensorProto_DataType_UINT64,
338 TensorProto_DataType_INT8,
339 TensorProto_DataType_INT16,
340 TensorProto_DataType_FLOAT16,
341 TensorProto_DataType_FLOAT,
342 TensorProto_DataType_DOUBLE};
343 registerAdapter(make_unique<CompatibleAdapter>("ArgMax", OpSetID(11), OpSetID(10)));
344 registerAdapter(make_unique<CompatibleAdapter>("ArgMin", OpSetID(11), OpSetID(10)));
345 registerAdapter(make_unique<CompatibleAdapter>("AveragePool", OpSetID(11), OpSetID(10)));
346 registerAdapter(make_unique<CompatibleAdapter>("Concat", OpSetID(11), OpSetID(10)));
347 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(11), OpSetID(10)));
348 registerAdapter(make_unique<CompatibleAdapter>("Conv", OpSetID(11), OpSetID(10)));
349 registerAdapter(make_unique<CompatibleAdapter>("ConvTranspose", OpSetID(11), OpSetID(10)));
350 registerAdapter(make_unique<TypeRestriction>("Equal", OpSetID(11), OpSetID(10), equal_unallowed_types));
351 registerAdapter(make_unique<CompatibleAdapter>("Flatten", OpSetID(11), OpSetID(10)));
352 registerAdapter(make_unique<CompatibleAdapter>("LogSoftmax", OpSetID(11), OpSetID(10)));
353 registerAdapter(make_unique<CompatibleAdapter>("MaxPool", OpSetID(11), OpSetID(10)));
354 registerAdapter(make_unique<CompatibleAdapter>("ReduceL1", OpSetID(11), OpSetID(10)));
355 registerAdapter(make_unique<CompatibleAdapter>("ReduceL2", OpSetID(11), OpSetID(10)));
356 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(11), OpSetID(10)));
357 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(11), OpSetID(10)));
358 registerAdapter(make_unique<CompatibleAdapter>("ReduceMax", OpSetID(11), OpSetID(10)));
359 registerAdapter(make_unique<CompatibleAdapter>("ReduceMean", OpSetID(11), OpSetID(10)));
360 registerAdapter(make_unique<CompatibleAdapter>("ReduceMin", OpSetID(11), OpSetID(10)));
361 registerAdapter(make_unique<CompatibleAdapter>("ReduceProd", OpSetID(11), OpSetID(10)));
362 registerAdapter(make_unique<CompatibleAdapter>("ReduceSum", OpSetID(11), OpSetID(10)));
363 registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(11), OpSetID(10)));
364 registerAdapter(make_unique<CompatibleAdapter>("Softmax", OpSetID(11), OpSetID(10)));
365 registerAdapter(make_unique<CompatibleAdapter>("Unsqueeze", OpSetID(11), OpSetID(10)));
366
367 /******** 11 -> 12 ********/
368 registerAdapter(make_unique<CompatibleAdapter>("ArgMax", OpSetID(11), OpSetID(12)));
369 registerAdapter(make_unique<CompatibleAdapter>("ArgMin", OpSetID(11), OpSetID(12)));
370 registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(11), OpSetID(12)));
371 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(11), OpSetID(12)));
372 registerAdapter(make_unique<CompatibleAdapter>("Clip", OpSetID(11), OpSetID(12)));
373 registerAdapter(make_unique<CompatibleAdapter>("GatherND", OpSetID(11), OpSetID(12)));
374 registerAdapter(make_unique<CompatibleAdapter>("Min", OpSetID(11), OpSetID(12)));
375 registerAdapter(make_unique<CompatibleAdapter>("Max", OpSetID(11), OpSetID(12)));
376 registerAdapter(make_unique<CompatibleAdapter>("MaxPool", OpSetID(11), OpSetID(12)));
377 registerAdapter(make_unique<CompatibleAdapter>("Pow", OpSetID(11), OpSetID(12)));
378 registerAdapter(make_unique<CompatibleAdapter>("ReduceMax", OpSetID(11), OpSetID(12)));
379 registerAdapter(make_unique<CompatibleAdapter>("ReduceMin", OpSetID(11), OpSetID(12)));
380 registerAdapter(make_unique<Dropout_11_12>());
381
382 /******** 12 -> 11 ********/
383 std::vector<TensorProto_DataType> maxpool_unallowed_types = {TensorProto_DataType_UINT8, TensorProto_DataType_INT8};
384 registerAdapter("ArgMax", 12, 11, RemoveAttribute(kselect_last_index, 0));
385 registerAdapter("ArgMin", 12, 11, RemoveAttribute(kselect_last_index, 0));
386 registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(12), OpSetID(11)));
387 registerAdapter(make_unique<TypeRestriction>("Clip", OpSetID(12), OpSetID(11), int_unallowed_types));
388 registerAdapter(make_unique<TypeRestriction>("Min", OpSetID(12), OpSetID(11), int_unallowed_types));
389 registerAdapter(make_unique<TypeRestriction>("Max", OpSetID(12), OpSetID(11), int_unallowed_types));
390 registerAdapter(make_unique<TypeRestriction>("MaxPool", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
391 registerAdapter(make_unique<TypeRestriction>("ReduceMax", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
392 registerAdapter(make_unique<TypeRestriction>("ReduceMin", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
393
394 /******** 12 -> 13 ********/
395 registerAdapter(make_unique<CompatibleAdapter>("Abs", OpSetID(12), OpSetID(13)));
396 registerAdapter(make_unique<CompatibleAdapter>("Add", OpSetID(12), OpSetID(13)));
397 registerAdapter(make_unique<CompatibleAdapter>("ArgMin", OpSetID(12), OpSetID(13)));
398 registerAdapter(make_unique<CompatibleAdapter>("ArgMax", OpSetID(12), OpSetID(13)));
399 registerAdapter(make_unique<CompatibleAdapter>("Cast", OpSetID(12), OpSetID(13)));
400 registerAdapter(make_unique<CompatibleAdapter>("Ceil", OpSetID(12), OpSetID(13)));
401 registerAdapter(make_unique<CompatibleAdapter>("Clip", OpSetID(12), OpSetID(13)));
402 registerAdapter(make_unique<CompatibleAdapter>("Concat", OpSetID(12), OpSetID(13)));
403 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(12), OpSetID(13)));
404 registerAdapter(make_unique<CompatibleAdapter>("DepthToSpace", OpSetID(12), OpSetID(13)));
405 registerAdapter(make_unique<CompatibleAdapter>("DequantizeLinear", OpSetID(12), OpSetID(13)));
406 registerAdapter(make_unique<CompatibleAdapter>("Div", OpSetID(12), OpSetID(13)));
407 registerAdapter(make_unique<CompatibleAdapter>("Dropout", OpSetID(12), OpSetID(13)));
408 registerAdapter(make_unique<CompatibleAdapter>("Equal", OpSetID(12), OpSetID(13)));
409 registerAdapter(make_unique<CompatibleAdapter>("Erf", OpSetID(12), OpSetID(13)));
410 registerAdapter(make_unique<CompatibleAdapter>("Exp", OpSetID(12), OpSetID(13)));
411 registerAdapter(make_unique<CompatibleAdapter>("Expand", OpSetID(12), OpSetID(13)));
412 registerAdapter(make_unique<CompatibleAdapter>("Flatten", OpSetID(12), OpSetID(13)));
413 registerAdapter(make_unique<CompatibleAdapter>("Floor", OpSetID(12), OpSetID(13)));
414 registerAdapter(make_unique<CompatibleAdapter>("Gather", OpSetID(12), OpSetID(13)));
415 registerAdapter(make_unique<CompatibleAdapter>("GatherElements", OpSetID(12), OpSetID(13)));
416 registerAdapter(make_unique<CompatibleAdapter>("GatherND", OpSetID(12), OpSetID(13)));
417 registerAdapter(make_unique<CompatibleAdapter>("Gemm", OpSetID(12), OpSetID(13)));
418 registerAdapter(make_unique<CompatibleAdapter>("Greater", OpSetID(12), OpSetID(13)));
419 registerAdapter(make_unique<CompatibleAdapter>("Hardmax", OpSetID(12), OpSetID(13)));
420 registerAdapter(make_unique<CompatibleAdapter>("Identity", OpSetID(12), OpSetID(13)));
421 registerAdapter(make_unique<CompatibleAdapter>("If", OpSetID(12), OpSetID(13)));
422 registerAdapter(make_unique<CompatibleAdapter>("IsNaN", OpSetID(12), OpSetID(13)));
423 registerAdapter(make_unique<CompatibleAdapter>("Less", OpSetID(12), OpSetID(13)));
424 registerAdapter(make_unique<CompatibleAdapter>("Log", OpSetID(12), OpSetID(13)));
425 registerAdapter(make_unique<CompatibleAdapter>("Loop", OpSetID(12), OpSetID(13)));
426 registerAdapter(make_unique<CompatibleAdapter>("LRN", OpSetID(12), OpSetID(13)));
427 registerAdapter(make_unique<CompatibleAdapter>("NegativeLogLikelihoodLoss", OpSetID(12), OpSetID(13)));
428 registerAdapter(make_unique<CompatibleAdapter>("MatMul", OpSetID(12), OpSetID(13)));
429 registerAdapter(make_unique<CompatibleAdapter>("Max", OpSetID(12), OpSetID(13)));
430 registerAdapter(make_unique<CompatibleAdapter>("Mean", OpSetID(12), OpSetID(13)));
431 registerAdapter(make_unique<CompatibleAdapter>("MeanVarianceNormalization", OpSetID(12), OpSetID(13)));
432 registerAdapter(make_unique<CompatibleAdapter>("Min", OpSetID(12), OpSetID(13)));
433 registerAdapter(make_unique<CompatibleAdapter>("Mod", OpSetID(12), OpSetID(13)));
434 registerAdapter(make_unique<CompatibleAdapter>("Mul", OpSetID(12), OpSetID(13)));
435 registerAdapter(make_unique<CompatibleAdapter>("Neg", OpSetID(12), OpSetID(13)));
436 registerAdapter(make_unique<CompatibleAdapter>("NonZero", OpSetID(12), OpSetID(13)));
437 registerAdapter(make_unique<CompatibleAdapter>("Pow", OpSetID(12), OpSetID(13)));
438 registerAdapter(make_unique<CompatibleAdapter>("Pad", OpSetID(12), OpSetID(13)));
439 registerAdapter(make_unique<CompatibleAdapter>("QuantizeLinear", OpSetID(12), OpSetID(13)));
440 registerAdapter(make_unique<CompatibleAdapter>("Reciprocal", OpSetID(12), OpSetID(13)));
441 registerAdapter(make_unique<CompatibleAdapter>("ReduceL1", OpSetID(12), OpSetID(13)));
442 registerAdapter(make_unique<CompatibleAdapter>("ReduceL2", OpSetID(12), OpSetID(13)));
443 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(12), OpSetID(13)));
444 registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(12), OpSetID(13)));
445 registerAdapter(make_unique<CompatibleAdapter>("ReduceMean", OpSetID(12), OpSetID(13)));
446 registerAdapter(make_unique<CompatibleAdapter>("ReduceMax", OpSetID(12), OpSetID(13)));
447 registerAdapter(make_unique<CompatibleAdapter>("ReduceMin", OpSetID(12), OpSetID(13)));
448 registerAdapter(make_unique<CompatibleAdapter>("ReduceProd", OpSetID(12), OpSetID(13)));
449 registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(12), OpSetID(13)));
450 registerAdapter(make_unique<CompatibleAdapter>("Relu", OpSetID(12), OpSetID(13)));
451 registerAdapter(make_unique<CompatibleAdapter>("Reshape", OpSetID(12), OpSetID(13)));
452 registerAdapter(make_unique<CompatibleAdapter>("Resize", OpSetID(12), OpSetID(13)));
453 registerAdapter(make_unique<CompatibleAdapter>("ScatterElements", OpSetID(12), OpSetID(13)));
454 registerAdapter(make_unique<CompatibleAdapter>("ScatterND", OpSetID(12), OpSetID(13)));
455 registerAdapter(make_unique<CompatibleAdapter>("Shape", OpSetID(12), OpSetID(13)));
456 registerAdapter(make_unique<CompatibleAdapter>("Sigmoid", OpSetID(12), OpSetID(13)));
457 registerAdapter(make_unique<CompatibleAdapter>("Sign", OpSetID(12), OpSetID(13)));
458 registerAdapter(make_unique<CompatibleAdapter>("Size", OpSetID(12), OpSetID(13)));
459 registerAdapter(make_unique<CompatibleAdapter>("Slice", OpSetID(12), OpSetID(13)));
460 registerAdapter(make_unique<CompatibleAdapter>("SoftmaxCrossEntropyLoss", OpSetID(12), OpSetID(13)));
461 registerAdapter(make_unique<CompatibleAdapter>("SpaceToDepth", OpSetID(12), OpSetID(13)));
462 registerAdapter(make_unique<CompatibleAdapter>("Sqrt", OpSetID(12), OpSetID(13)));
463 registerAdapter(make_unique<CompatibleAdapter>("Sub", OpSetID(12), OpSetID(13)));
464 registerAdapter(make_unique<CompatibleAdapter>("Sum", OpSetID(12), OpSetID(13)));
465 registerAdapter(make_unique<CompatibleAdapter>("Tanh", OpSetID(12), OpSetID(13)));
466 registerAdapter(make_unique<CompatibleAdapter>("Tile", OpSetID(12), OpSetID(13)));
467 registerAdapter(make_unique<CompatibleAdapter>("Transpose", OpSetID(12), OpSetID(13)));
468 registerAdapter(make_unique<AxesAttributeToInput>("ReduceSum", OpSetID(12), OpSetID(13)));
469 registerAdapter(make_unique<AxesAttributeToInput>("Squeeze", OpSetID(12), OpSetID(13)));
470 registerAdapter(make_unique<AxesAttributeToInput>("Unsqueeze", OpSetID(12), OpSetID(13)));
471 registerAdapter(make_unique<Split_12_13>());
472 registerAdapter(make_unique<Softmax_12_13>("Softmax"));
473 registerAdapter(make_unique<Softmax_12_13>("LogSoftmax"));
474
475 /******** 13 -> 12 ********/
476 registerAdapter(make_unique<CompatibleAdapter>("Constant", OpSetID(13), OpSetID(12)));
477 registerAdapter(make_unique<AxesInputToAttribute>("ReduceSum", OpSetID(13), OpSetID(12)));
478 registerAdapter(make_unique<AxesInputToAttribute>("Squeeze", OpSetID(13), OpSetID(12)));
479 registerAdapter(make_unique<AxesInputToAttribute>("Unsqueeze", OpSetID(13), OpSetID(12)));
480 registerAdapter(make_unique<Split_13_12>());
481
482 /******** 13 -> 14 ********/
483 registerAdapter(make_unique<CompatibleAdapter>("Add", OpSetID(13), OpSetID(14)));
484 registerAdapter(make_unique<CompatibleAdapter>("CumSum", OpSetID(13), OpSetID(14)));
485 registerAdapter(make_unique<CompatibleAdapter>("Div", OpSetID(13), OpSetID(14)));
486 registerAdapter(make_unique<CompatibleAdapter>("Identity", OpSetID(13), OpSetID(14)));
487 registerAdapter(make_unique<CompatibleAdapter>("Mul", OpSetID(13), OpSetID(14)));
488 registerAdapter(make_unique<CompatibleAdapter>("Relu", OpSetID(13), OpSetID(14)));
489 registerAdapter(make_unique<CompatibleAdapter>("Reshape", OpSetID(13), OpSetID(14)));
490 registerAdapter(make_unique<CompatibleAdapter>("Sub", OpSetID(13), OpSetID(14)));
491 registerAdapter("GRU", 13, 14, SetAttribute(klayout, 0));
492 registerAdapter("LSTM", 13, 14, SetAttribute(klayout, 0));
493 registerAdapter("RNN", 13, 14, SetAttribute(klayout, 0));
494 registerAdapter(make_unique<BatchNormalization_13_14>());
495
496 /******** 14 -> 13 ********/
497 registerAdapter("GRU", 14, 13, RemoveAttribute(klayout, 0));
498 registerAdapter("LSTM", 14, 13, RemoveAttribute(klayout, 0));
499 registerAdapter("RNN", 14, 13, RemoveAttribute(klayout, 0));
500
501 /******** 14 -> 15 ********/
502 registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(14), OpSetID(15)));
503 registerAdapter(make_unique<CompatibleAdapter>("Pow", OpSetID(14), OpSetID(15)));
504 registerAdapter(make_unique<CompatibleAdapter>("Shape", OpSetID(14), OpSetID(15)));
505
506 /******** 15 -> 16 ********/
507 registerAdapter("RoiAlign", 15, 16, SetAttribute(kcoordinate_transformation_mode, "output_half_pixel"));
508 registerAdapter(make_unique<CompatibleAdapter>("ScatterElements", OpSetID(15), OpSetID(16)));
509 registerAdapter(make_unique<CompatibleAdapter>("ScatterND", OpSetID(15), OpSetID(16)));
510 registerAdapter(make_unique<CompatibleAdapter>("Identity", OpSetID(15), OpSetID(16)));
511 registerAdapter(make_unique<CompatibleAdapter>("Loop", OpSetID(15), OpSetID(16)));
512 registerAdapter(make_unique<CompatibleAdapter>("If", OpSetID(15), OpSetID(16)));
513 registerAdapter(make_unique<CompatibleAdapter>("Where", OpSetID(15), OpSetID(16)));
514 registerAdapter(make_unique<CompatibleAdapter>("Scan", OpSetID(15), OpSetID(16)));
515 registerAdapter(make_unique<CompatibleAdapter>("LessOrEqual", OpSetID(15), OpSetID(16)));
516 registerAdapter(make_unique<CompatibleAdapter>("GreaterOrEqual", OpSetID(15), OpSetID(16)));
517 registerAdapter(make_unique<CompatibleAdapter>("LeakyRelu", OpSetID(15), OpSetID(16)));
518 registerAdapter(make_unique<CompatibleAdapter>("PRelu", OpSetID(15), OpSetID(16)));
519
520 /******** 17 -> 18 ********/
521 registerAdapter(make_unique<CompatibleAdapter>("Pad", OpSetID(17), OpSetID(18)));
522 registerAdapter(make_unique<CompatibleAdapter>("Resize", OpSetID(17), OpSetID(18)));
523 registerAdapter(make_unique<CompatibleAdapter>("OptionalGetElement", OpSetID(17), OpSetID(18)));
524 registerAdapter(make_unique<CompatibleAdapter>("OptionalHasElement", OpSetID(17), OpSetID(18)));
525 registerAdapter(make_unique<Split_17_18>());
526 registerAdapter(make_unique<CompatibleAdapter>("ScatterND", OpSetID(17), OpSetID(18)));
527 registerAdapter(make_unique<CompatibleAdapter>("ScatterElements", OpSetID(17), OpSetID(18)));
528 registerAdapter("LpPool", 17, 18, SetAttribute(kceil_mode, 0));
529 registerAdapter(make_unique<AxesAttributeToInput>("ReduceL1", OpSetID(17), OpSetID(18)));
530 registerAdapter(make_unique<AxesAttributeToInput>("ReduceL2", OpSetID(17), OpSetID(18)));
531 registerAdapter(make_unique<AxesAttributeToInput>("ReduceLogSum", OpSetID(17), OpSetID(18)));
532 registerAdapter(make_unique<AxesAttributeToInput>("ReduceLogSumExp", OpSetID(17), OpSetID(18)));
533 registerAdapter(make_unique<AxesAttributeToInput>("ReduceMax", OpSetID(17), OpSetID(18)));
534 registerAdapter(make_unique<AxesAttributeToInput>("ReduceMean", OpSetID(17), OpSetID(18)));
535 registerAdapter(make_unique<AxesAttributeToInput>("ReduceMin", OpSetID(17), OpSetID(18)));
536 registerAdapter(make_unique<AxesAttributeToInput>("ReduceProd", OpSetID(17), OpSetID(18)));
537 registerAdapter(make_unique<AxesAttributeToInput>("ReduceSumSquare", OpSetID(17), OpSetID(18)));
538 }
539
540 ModelProto convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version)
541 const override;
542};
543
// Free-function entry point: takes the input model by const reference and
// returns a new ModelProto intended to target opset `target_version`.
// NOTE(review): the definition is not in this header; presumably it reads the
// model's current default-domain opset import and delegates to
// DefaultVersionConverter::convert_version — confirm in the .cc file.
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version);
545} // namespace version_conversion
546} // namespace ONNX_NAMESPACE
547