1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // Default converter for ONNX models between different opset versions |
6 | // in the default domain ("" or "ai.onnx"). |
7 | |
8 | #pragma once |
9 | |
10 | #include "onnx/version_converter/BaseConverter.h" |
11 | #include "onnx/version_converter/adapters/axes_attribute_to_input.h" |
12 | #include "onnx/version_converter/adapters/axes_input_to_attribute.h" |
13 | #include "onnx/version_converter/adapters/batch_normalization_13_14.h" |
14 | #include "onnx/version_converter/adapters/broadcast_backward_compatibility.h" |
15 | #include "onnx/version_converter/adapters/broadcast_forward_compatibility.h" |
16 | #include "onnx/version_converter/adapters/cast_9_8.h" |
17 | #include "onnx/version_converter/adapters/clip_10_11.h" |
18 | #include "onnx/version_converter/adapters/compatible.h" |
19 | #include "onnx/version_converter/adapters/dropout_11_12.h" |
20 | #include "onnx/version_converter/adapters/extend_supported_types.h" |
21 | #include "onnx/version_converter/adapters/gemm_6_7.h" |
22 | #include "onnx/version_converter/adapters/gemm_7_6.h" |
23 | #include "onnx/version_converter/adapters/maxpool_8_7.h" |
24 | #include "onnx/version_converter/adapters/no_previous_version.h" |
25 | #include "onnx/version_converter/adapters/pad_10_11.h" |
26 | #include "onnx/version_converter/adapters/reshape_4_5.h" |
27 | #include "onnx/version_converter/adapters/reshape_5_4.h" |
28 | #include "onnx/version_converter/adapters/resize_10_11.h" |
29 | #include "onnx/version_converter/adapters/scan_8_9.h" |
30 | #include "onnx/version_converter/adapters/scan_9_8.h" |
31 | #include "onnx/version_converter/adapters/scatter_10_11.h" |
32 | #include "onnx/version_converter/adapters/slice_9_10.h" |
33 | #include "onnx/version_converter/adapters/softmax_12_13.h" |
34 | #include "onnx/version_converter/adapters/split_12_13.h" |
35 | #include "onnx/version_converter/adapters/split_13_12.h" |
36 | #include "onnx/version_converter/adapters/split_17_18.h" |
37 | #include "onnx/version_converter/adapters/sum_8_7.h" |
38 | #include "onnx/version_converter/adapters/topk_9_10.h" |
39 | #include "onnx/version_converter/adapters/type_restriction.h" |
40 | #include "onnx/version_converter/adapters/upsample_6_7.h" |
41 | #include "onnx/version_converter/adapters/upsample_8_9.h" |
42 | #include "onnx/version_converter/adapters/upsample_9_10.h" |
43 | #include "onnx/version_converter/adapters/upsample_9_8.h" |
44 | |
45 | #include "onnx/version_converter/adapters/transformers.h" |
46 | |
47 | namespace ONNX_NAMESPACE { |
48 | namespace version_conversion { |
49 | |
50 | class DefaultVersionConverter : public BaseVersionConverter { |
51 | private: |
52 | bool DEBUG = false; |
53 | |
54 | std::pair<int, int> version_range; |
55 | |
56 | bool searchOpDomainMap( |
57 | const std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>& op_domain_map, |
58 | int64_t curr_version, |
59 | int64_t step) const { |
60 | bool up = step == 1; |
61 | const auto version_it = op_domain_map.find("" ); |
62 | return version_it != op_domain_map.end() && |
63 | ((version_it->second.find(curr_version) != version_it->second.end() && !up) || |
64 | (version_it->second.find(curr_version + step) != version_it->second.end() && up)); |
65 | } |
66 | |
67 | void debug(const std::string& str) const { |
68 | if (DEBUG) |
69 | std::cerr << str << std::endl; |
70 | } |
71 | |
72 | void assertInVersionRange(int64_t version) const { |
73 | ONNX_ASSERTM( |
74 | version >= version_range.first && version <= version_range.second, |
75 | "Warning: invalid version (must be between %d and %d)" , |
76 | version_range.first, |
77 | version_range.second); |
78 | } |
79 | |
80 | void assertDefaultDomain(const std::string& initial_domain, const std::string& target_domain) const { |
81 | ONNX_ASSERTM( |
82 | (initial_domain == "" || initial_domain == "ai.onnx" ) && (target_domain == "" || target_domain == "ai.onnx" ), |
83 | "Warning: default onnx version converter can only convert " |
84 | " between default domain opset versions ('' or 'ai.onnx')\n" ); |
85 | ONNX_ASSERTM(initial_domain == target_domain, "initial_version and target_version must have the same domains" ); |
86 | } |
87 | |
// Declared here; the definition is outside this header view (presumably the
// matching .cc file). Judging by its name and parameters, it rewrites graph `g`
// from `initial_version` to `target_version` using the registered adapters —
// TODO confirm against the definition.
void convert_graph(std::shared_ptr<Graph> g, const OpSetID& initial_version, const OpSetID& target_version) const;
89 | |
90 | public: |
91 | DefaultVersionConverter() { |
92 | const std::unordered_map<std::string, std::pair<int, int>>& versions_map = |
93 | OpSchemaRegistry::DomainToVersionRange::Instance().Map(); |
94 | version_range = versions_map.at("" ); |
95 | // Register adapters to the version converter |
96 | const std::vector<OpSchema> all_opschemas = OpSchemaRegistry::get_all_schemas_with_history(); |
97 | |
98 | for (const OpSchema& schema : all_opschemas) { |
99 | all_schemas[schema.Name()][schema.domain()][(int64_t)schema.since_version()] = &schema; |
100 | } |
101 | |
102 | // Iterate through all_schemas to determine NoPreviousVersionAdapters |
103 | for (auto& op_pair : all_schemas) { |
104 | const auto default_versions = op_pair.second.find("" ); |
105 | if (default_versions != op_pair.second.end()) { |
106 | int64_t min_version = version_range.second; |
107 | for (auto& version_pair : default_versions->second) { |
108 | if (version_pair.first < min_version) { |
109 | min_version = version_pair.first; |
110 | } |
111 | } |
112 | if (min_version > 1) { |
113 | registerAdapter( |
114 | make_unique<NoPreviousVersionAdapter>(op_pair.first, OpSetID(min_version), OpSetID(min_version - 1))); |
115 | } |
116 | } |
117 | } |
118 | |
119 | /******** 1 -> 2 ********/ |
120 | // Missing in this group: GlobalLpPool, LpPool, Pad, Split |
121 | |
122 | /******** 2 -> 3 ********/ |
123 | // Missing in this group: GRU |
124 | |
125 | /******** 3 -> 4 ********/ |
126 | registerAdapter("Concat" , 3, 4, SetAttributeIfAbsent(kaxis, 1)); |
127 | |
128 | /******** 4 -> 3 ********/ |
129 | std::vector<TensorProto_DataType> concat_unallowed_types = { |
130 | TensorProto_DataType_INT32, |
131 | TensorProto_DataType_INT64, |
132 | TensorProto_DataType_UINT32, |
133 | TensorProto_DataType_UINT64, |
134 | TensorProto_DataType_UINT8, |
135 | TensorProto_DataType_UINT16, |
136 | TensorProto_DataType_INT8, |
137 | TensorProto_DataType_INT16, |
138 | TensorProto_DataType_STRING, |
139 | TensorProto_DataType_BOOL}; |
140 | registerAdapter(make_unique<TypeRestriction>("Concat" , OpSetID(4), OpSetID(3), concat_unallowed_types)); |
141 | |
142 | /******** 4 -> 5 ********/ |
143 | registerAdapter(make_unique<Reshape_4_5>()); |
144 | |
145 | /******** 5 -> 4 ********/ |
146 | registerAdapter(make_unique<Reshape_5_4>()); |
147 | |
148 | /******** 5 -> 6 ********/ |
149 | // Missing in this group: Cast, Tile |
150 | auto removeConsumedInputs = RemoveAttribute(kconsumed_inputs); |
151 | registerAdapter("Add" , 5, 6, removeConsumedInputs); |
152 | registerAdapter("Mul" , 5, 6, removeConsumedInputs); |
153 | registerAdapter(make_unique<CompatibleAdapter>("Gemm" , OpSetID(5), OpSetID(6))); |
154 | registerAdapter("Relu" , 5, 6, removeConsumedInputs); |
155 | registerAdapter("BatchNormalization" , 5, 6, removeConsumedInputs); |
156 | registerAdapter("Sum" , 5, 6, removeConsumedInputs); |
157 | registerAdapter("Dropout" , 5, 6, removeConsumedInputs); |
158 | registerAdapter("Abs" , 5, 6, removeConsumedInputs); |
159 | registerAdapter("Ceil" , 5, 6, removeConsumedInputs); |
160 | registerAdapter("Clip" , 5, 6, removeConsumedInputs); |
161 | registerAdapter("Div" , 5, 6, removeConsumedInputs); |
162 | registerAdapter("Elu" , 5, 6, removeConsumedInputs); |
163 | registerAdapter("Exp" , 5, 6, removeConsumedInputs); |
164 | registerAdapter("Floor" , 5, 6, removeConsumedInputs); |
165 | registerAdapter("HardSigmoid" , 5, 6, removeConsumedInputs); |
166 | registerAdapter("InstanceNormalization" , 5, 6, removeConsumedInputs); |
167 | registerAdapter("LeakyRelu" , 5, 6, removeConsumedInputs); |
168 | registerAdapter("Log" , 5, 6, removeConsumedInputs); |
169 | registerAdapter("Max" , 5, 6, removeConsumedInputs); |
170 | registerAdapter("Mean" , 5, 6, removeConsumedInputs); |
171 | registerAdapter("Min" , 5, 6, removeConsumedInputs); |
172 | registerAdapter("Neg" , 5, 6, removeConsumedInputs); |
173 | registerAdapter("PRelu" , 5, 6, removeConsumedInputs); |
174 | registerAdapter("Reciprocal" , 5, 6, removeConsumedInputs); |
175 | registerAdapter("Selu" , 5, 6, removeConsumedInputs); |
176 | registerAdapter("Sigmoid" , 5, 6, removeConsumedInputs); |
177 | registerAdapter("Sqrt" , 5, 6, removeConsumedInputs); |
178 | registerAdapter("Sub" , 5, 6, removeConsumedInputs); |
179 | registerAdapter("Tanh" , 5, 6, removeConsumedInputs); |
180 | |
181 | /******** 6 -> 5 ********/ |
182 | std::vector<TensorProto_DataType> broadcast_unallowed_types = { |
183 | TensorProto_DataType_INT32, |
184 | TensorProto_DataType_INT64, |
185 | TensorProto_DataType_UINT32, |
186 | TensorProto_DataType_UINT64}; |
187 | std::vector<TensorProto_DataType> int_unallowed_types = { |
188 | TensorProto_DataType_UINT8, |
189 | TensorProto_DataType_UINT16, |
190 | TensorProto_DataType_UINT32, |
191 | TensorProto_DataType_UINT64, |
192 | TensorProto_DataType_INT8, |
193 | TensorProto_DataType_INT16, |
194 | TensorProto_DataType_INT32, |
195 | TensorProto_DataType_INT64}; |
196 | std::vector<TensorProto_DataType> neg_unallowed_types = { |
197 | TensorProto_DataType_INT32, TensorProto_DataType_INT8, TensorProto_DataType_UINT16, TensorProto_DataType_INT64}; |
198 | registerAdapter(make_unique<TypeRestriction>("Add" , OpSetID(6), OpSetID(5), broadcast_unallowed_types)); |
199 | registerAdapter(make_unique<TypeRestriction>("Mul" , OpSetID(6), OpSetID(5), broadcast_unallowed_types)); |
200 | registerAdapter(make_unique<TypeRestriction>("Sub" , OpSetID(6), OpSetID(5), broadcast_unallowed_types)); |
201 | registerAdapter(make_unique<TypeRestriction>("Div" , OpSetID(6), OpSetID(5), broadcast_unallowed_types)); |
202 | registerAdapter(make_unique<TypeRestriction>("Abs" , OpSetID(6), OpSetID(5), int_unallowed_types)); |
203 | registerAdapter(make_unique<TypeRestriction>("Neg" , OpSetID(6), OpSetID(5), neg_unallowed_types)); |
204 | registerAdapter("BatchNormalization" , 6, 5, SetAttribute(kconsumed_inputs, std::vector<int64_t>({0, 0}))); |
205 | registerAdapter(make_unique<CompatibleAdapter>("Gemm" , OpSetID(6), OpSetID(5))); |
206 | registerAdapter(make_unique<CompatibleAdapter>("Relu" , OpSetID(6), OpSetID(5))); |
207 | registerAdapter(make_unique<CompatibleAdapter>("Sum" , OpSetID(6), OpSetID(5))); |
208 | registerAdapter(make_unique<CompatibleAdapter>("Dropout" , OpSetID(6), OpSetID(5))); |
209 | |
210 | /******** 6 -> 7 ********/ |
211 | // Missing in this group: And, Equal, Greater, GRU, Less, LSTM, Or, RNN, Upsample, Xor |
212 | registerAdapter(make_unique<BroadcastForwardCompatibility>("Add" , OpSetID(6), OpSetID(7))); |
213 | registerAdapter(make_unique<CompatibleAdapter>("AveragePool" , OpSetID(6), OpSetID(7))); |
214 | registerAdapter(make_unique<BroadcastForwardCompatibility>("Div" , OpSetID(6), OpSetID(7))); |
215 | registerAdapter(make_unique<BroadcastForwardCompatibility>("Mul" , OpSetID(6), OpSetID(7))); |
216 | registerAdapter(make_unique<BroadcastForwardCompatibility>("Pow" , OpSetID(6), OpSetID(7))); |
217 | registerAdapter(make_unique<CompatibleAdapter>("PRelu" , OpSetID(6), OpSetID(7))); |
218 | registerAdapter(make_unique<BroadcastForwardCompatibility>("Sub" , OpSetID(6), OpSetID(7))); |
219 | registerAdapter(make_unique<Gemm_6_7>()); |
220 | registerAdapter("BatchNormalization" , 6, 7, RemoveAttributeNotEq(kis_test, 0)); |
221 | registerAdapter("Dropout" , 6, 7, RemoveAttributeNotEq(kis_test, 0)); |
222 | registerAdapter(make_unique<Upsample_6_7>()); |
223 | |
224 | /******** 7 -> 6 ********/ |
225 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Add" , OpSetID(7), OpSetID(6))); |
226 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Div" , OpSetID(7), OpSetID(6))); |
227 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Mul" , OpSetID(7), OpSetID(6))); |
228 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Pow" , OpSetID(7), OpSetID(6))); |
229 | registerAdapter(make_unique<CompatibleAdapter>("PRelu" , OpSetID(7), OpSetID(6))); |
230 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Sub" , OpSetID(7), OpSetID(6))); |
231 | registerAdapter("BatchNormalization" , 7, 6, SetAttribute(kis_test, 1)); |
232 | registerAdapter("Dropout" , 7, 6, SetAttribute(kis_test, 1)); |
233 | registerAdapter(make_unique<Gemm_7_6>()); |
234 | registerAdapter("AveragePool" , 7, 6, RemoveAttribute(kcount_include_pad, 0)); |
235 | |
236 | /******** 7 -> 8 ********/ |
237 | registerAdapter(make_unique<CompatibleAdapter>("Max" , OpSetID(7), OpSetID(8))); |
238 | registerAdapter(make_unique<CompatibleAdapter>("Min" , OpSetID(7), OpSetID(8))); |
239 | registerAdapter(make_unique<CompatibleAdapter>("Mean" , OpSetID(7), OpSetID(8))); |
240 | registerAdapter(make_unique<CompatibleAdapter>("Sum" , OpSetID(7), OpSetID(8))); |
241 | registerAdapter(make_unique<CompatibleAdapter>("MaxPool" , OpSetID(7), OpSetID(8))); |
242 | |
243 | /******** 8 -> 7 ********/ |
244 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Max" , OpSetID(8), OpSetID(7))); |
245 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Min" , OpSetID(8), OpSetID(7))); |
246 | registerAdapter(make_unique<BroadcastBackwardCompatibility>("Mean" , OpSetID(8), OpSetID(7))); |
247 | registerAdapter(make_unique<Sum_8_7>()); |
248 | registerAdapter(make_unique<MaxPool_8_7>()); |
249 | |
250 | /******** 8 -> 9 ********/ |
251 | registerAdapter(make_unique<CompatibleAdapter>("Flatten" , OpSetID(8), OpSetID(9))); |
252 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(8), OpSetID(9))); |
253 | registerAdapter(make_unique<CompatibleAdapter>("MatMul" , OpSetID(8), OpSetID(9))); |
254 | registerAdapter(make_unique<CompatibleAdapter>("Gemm" , OpSetID(8), OpSetID(9))); |
255 | registerAdapter(make_unique<CompatibleAdapter>("PRelu" , OpSetID(8), OpSetID(9))); |
256 | registerAdapter(make_unique<CompatibleAdapter>("Greater" , OpSetID(8), OpSetID(9))); |
257 | registerAdapter(make_unique<CompatibleAdapter>("Less" , OpSetID(8), OpSetID(9))); |
258 | registerAdapter(make_unique<CompatibleAdapter>("Cast" , OpSetID(8), OpSetID(9))); |
259 | registerAdapter("BatchNormalization" , 8, 9, RemoveAttribute(kspatial, 1)); |
260 | registerAdapter(make_unique<Scan_8_9>()); |
261 | registerAdapter(make_unique<Upsample_8_9>()); |
262 | |
263 | /******** 9 -> 8 ********/ |
264 | registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization" , OpSetID(9), OpSetID(8))); |
265 | registerAdapter(make_unique<ExtendSupportedTypes>("Flatten" , OpSetID(9), OpSetID(8))); |
266 | registerAdapter(make_unique<ExtendSupportedTypes>("Constant" , OpSetID(9), OpSetID(8))); |
267 | registerAdapter(make_unique<ExtendSupportedTypes>("MatMul" , OpSetID(9), OpSetID(8))); |
268 | registerAdapter(make_unique<ExtendSupportedTypes>("Gemm" , OpSetID(9), OpSetID(8))); |
269 | registerAdapter(make_unique<ExtendSupportedTypes>("PRelu" , OpSetID(9), OpSetID(8))); |
270 | registerAdapter(make_unique<ExtendSupportedTypes>("Greater" , OpSetID(9), OpSetID(8))); |
271 | registerAdapter(make_unique<ExtendSupportedTypes>("Less" , OpSetID(9), OpSetID(8))); |
272 | registerAdapter(make_unique<Cast_9_8>()); |
273 | registerAdapter(make_unique<Scan_9_8>()); |
274 | registerAdapter(make_unique<Upsample_9_8>()); |
275 | |
276 | /******** 9 -> 10 ********/ |
277 | registerAdapter(make_unique<CompatibleAdapter>("AveragePool" , OpSetID(9), OpSetID(10))); |
278 | registerAdapter(make_unique<CompatibleAdapter>("MaxPool" , OpSetID(9), OpSetID(10))); |
279 | registerAdapter(make_unique<CompatibleAdapter>("Dropout" , OpSetID(9), OpSetID(10))); |
280 | registerAdapter(make_unique<Slice_9_10>()); |
281 | registerAdapter(make_unique<TopK_9_10>()); |
282 | registerAdapter(make_unique<Upsample_9_10>()); |
283 | |
284 | /******** 10 -> 9 ********/ |
285 | registerAdapter(make_unique<CompatibleAdapter>("Dropout" , OpSetID(10), OpSetID(9))); |
286 | |
287 | /******** 10 -> 11 ********/ |
288 | registerAdapter(make_unique<CompatibleAdapter>("ArgMax" , OpSetID(10), OpSetID(11))); |
289 | registerAdapter(make_unique<CompatibleAdapter>("ArgMin" , OpSetID(10), OpSetID(11))); |
290 | registerAdapter(make_unique<CompatibleAdapter>("AveragePool" , OpSetID(10), OpSetID(11))); |
291 | registerAdapter(make_unique<CompatibleAdapter>("Concat" , OpSetID(10), OpSetID(11))); |
292 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(10), OpSetID(11))); |
293 | registerAdapter(make_unique<CompatibleAdapter>("Compress" , OpSetID(10), OpSetID(11))); |
294 | registerAdapter(make_unique<CompatibleAdapter>("Conv" , OpSetID(10), OpSetID(11))); |
295 | registerAdapter(make_unique<CompatibleAdapter>("ConvTranspose" , OpSetID(10), OpSetID(11))); |
296 | registerAdapter(make_unique<CompatibleAdapter>("DepthToSpace" , OpSetID(10), OpSetID(11))); |
297 | registerAdapter(make_unique<CompatibleAdapter>("Equal" , OpSetID(10), OpSetID(11))); |
298 | registerAdapter(make_unique<CompatibleAdapter>("Flatten" , OpSetID(10), OpSetID(11))); |
299 | registerAdapter(make_unique<CompatibleAdapter>("Gather" , OpSetID(10), OpSetID(11))); |
300 | registerAdapter(make_unique<CompatibleAdapter>("Gemm" , OpSetID(10), OpSetID(11))); |
301 | registerAdapter(make_unique<CompatibleAdapter>("Hardmax" , OpSetID(10), OpSetID(11))); |
302 | registerAdapter(make_unique<CompatibleAdapter>("If" , OpSetID(10), OpSetID(11))); |
303 | registerAdapter(make_unique<CompatibleAdapter>("LogSoftmax" , OpSetID(10), OpSetID(11))); |
304 | registerAdapter(make_unique<CompatibleAdapter>("Loop" , OpSetID(10), OpSetID(11))); |
305 | registerAdapter(make_unique<CompatibleAdapter>("LpPool" , OpSetID(10), OpSetID(11))); |
306 | registerAdapter(make_unique<CompatibleAdapter>("MaxPool" , OpSetID(10), OpSetID(11))); |
307 | registerAdapter(make_unique<CompatibleAdapter>("MaxUnpool" , OpSetID(10), OpSetID(11))); |
308 | registerAdapter(make_unique<CompatibleAdapter>("NonMaxSuppression" , OpSetID(10), OpSetID(11))); |
309 | registerAdapter(make_unique<CompatibleAdapter>("OneHot" , OpSetID(10), OpSetID(11))); |
310 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL1" , OpSetID(10), OpSetID(11))); |
311 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL2" , OpSetID(10), OpSetID(11))); |
312 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum" , OpSetID(10), OpSetID(11))); |
313 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp" , OpSetID(10), OpSetID(11))); |
314 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMax" , OpSetID(10), OpSetID(11))); |
315 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMean" , OpSetID(10), OpSetID(11))); |
316 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMin" , OpSetID(10), OpSetID(11))); |
317 | registerAdapter(make_unique<CompatibleAdapter>("ReduceProd" , OpSetID(10), OpSetID(11))); |
318 | registerAdapter(make_unique<CompatibleAdapter>("ReduceSum" , OpSetID(10), OpSetID(11))); |
319 | registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare" , OpSetID(10), OpSetID(11))); |
320 | registerAdapter(make_unique<CompatibleAdapter>("Scan" , OpSetID(10), OpSetID(11))); |
321 | registerAdapter(make_unique<CompatibleAdapter>("Softmax" , OpSetID(10), OpSetID(11))); |
322 | registerAdapter(make_unique<CompatibleAdapter>("Slice" , OpSetID(10), OpSetID(11))); |
323 | registerAdapter(make_unique<CompatibleAdapter>("Split" , OpSetID(10), OpSetID(11))); |
324 | registerAdapter(make_unique<CompatibleAdapter>("Squeeze" , OpSetID(10), OpSetID(11))); |
325 | registerAdapter(make_unique<CompatibleAdapter>("TopK" , OpSetID(10), OpSetID(11))); |
326 | registerAdapter(make_unique<CompatibleAdapter>("Unsqueeze" , OpSetID(10), OpSetID(11))); |
327 | registerAdapter(make_unique<Clip_10_11>()); |
328 | registerAdapter(make_unique<Pad_10_11>()); |
329 | registerAdapter(make_unique<Resize_10_11>()); |
330 | registerAdapter(make_unique<Scatter_10_11>()); |
331 | |
332 | /******** 11 -> 10 ********/ |
333 | std::vector<TensorProto_DataType> equal_unallowed_types = { |
334 | TensorProto_DataType_UINT8, |
335 | TensorProto_DataType_UINT16, |
336 | TensorProto_DataType_UINT32, |
337 | TensorProto_DataType_UINT64, |
338 | TensorProto_DataType_INT8, |
339 | TensorProto_DataType_INT16, |
340 | TensorProto_DataType_FLOAT16, |
341 | TensorProto_DataType_FLOAT, |
342 | TensorProto_DataType_DOUBLE}; |
343 | registerAdapter(make_unique<CompatibleAdapter>("ArgMax" , OpSetID(11), OpSetID(10))); |
344 | registerAdapter(make_unique<CompatibleAdapter>("ArgMin" , OpSetID(11), OpSetID(10))); |
345 | registerAdapter(make_unique<CompatibleAdapter>("AveragePool" , OpSetID(11), OpSetID(10))); |
346 | registerAdapter(make_unique<CompatibleAdapter>("Concat" , OpSetID(11), OpSetID(10))); |
347 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(11), OpSetID(10))); |
348 | registerAdapter(make_unique<CompatibleAdapter>("Conv" , OpSetID(11), OpSetID(10))); |
349 | registerAdapter(make_unique<CompatibleAdapter>("ConvTranspose" , OpSetID(11), OpSetID(10))); |
350 | registerAdapter(make_unique<TypeRestriction>("Equal" , OpSetID(11), OpSetID(10), equal_unallowed_types)); |
351 | registerAdapter(make_unique<CompatibleAdapter>("Flatten" , OpSetID(11), OpSetID(10))); |
352 | registerAdapter(make_unique<CompatibleAdapter>("LogSoftmax" , OpSetID(11), OpSetID(10))); |
353 | registerAdapter(make_unique<CompatibleAdapter>("MaxPool" , OpSetID(11), OpSetID(10))); |
354 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL1" , OpSetID(11), OpSetID(10))); |
355 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL2" , OpSetID(11), OpSetID(10))); |
356 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum" , OpSetID(11), OpSetID(10))); |
357 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp" , OpSetID(11), OpSetID(10))); |
358 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMax" , OpSetID(11), OpSetID(10))); |
359 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMean" , OpSetID(11), OpSetID(10))); |
360 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMin" , OpSetID(11), OpSetID(10))); |
361 | registerAdapter(make_unique<CompatibleAdapter>("ReduceProd" , OpSetID(11), OpSetID(10))); |
362 | registerAdapter(make_unique<CompatibleAdapter>("ReduceSum" , OpSetID(11), OpSetID(10))); |
363 | registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare" , OpSetID(11), OpSetID(10))); |
364 | registerAdapter(make_unique<CompatibleAdapter>("Softmax" , OpSetID(11), OpSetID(10))); |
365 | registerAdapter(make_unique<CompatibleAdapter>("Unsqueeze" , OpSetID(11), OpSetID(10))); |
366 | |
367 | /******** 11 -> 12 ********/ |
368 | registerAdapter(make_unique<CompatibleAdapter>("ArgMax" , OpSetID(11), OpSetID(12))); |
369 | registerAdapter(make_unique<CompatibleAdapter>("ArgMin" , OpSetID(11), OpSetID(12))); |
370 | registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization" , OpSetID(11), OpSetID(12))); |
371 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(11), OpSetID(12))); |
372 | registerAdapter(make_unique<CompatibleAdapter>("Clip" , OpSetID(11), OpSetID(12))); |
373 | registerAdapter(make_unique<CompatibleAdapter>("GatherND" , OpSetID(11), OpSetID(12))); |
374 | registerAdapter(make_unique<CompatibleAdapter>("Min" , OpSetID(11), OpSetID(12))); |
375 | registerAdapter(make_unique<CompatibleAdapter>("Max" , OpSetID(11), OpSetID(12))); |
376 | registerAdapter(make_unique<CompatibleAdapter>("MaxPool" , OpSetID(11), OpSetID(12))); |
377 | registerAdapter(make_unique<CompatibleAdapter>("Pow" , OpSetID(11), OpSetID(12))); |
378 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMax" , OpSetID(11), OpSetID(12))); |
379 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMin" , OpSetID(11), OpSetID(12))); |
380 | registerAdapter(make_unique<Dropout_11_12>()); |
381 | |
382 | /******** 12 -> 11 ********/ |
383 | std::vector<TensorProto_DataType> maxpool_unallowed_types = {TensorProto_DataType_UINT8, TensorProto_DataType_INT8}; |
384 | registerAdapter("ArgMax" , 12, 11, RemoveAttribute(kselect_last_index, 0)); |
385 | registerAdapter("ArgMin" , 12, 11, RemoveAttribute(kselect_last_index, 0)); |
386 | registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization" , OpSetID(12), OpSetID(11))); |
387 | registerAdapter(make_unique<TypeRestriction>("Clip" , OpSetID(12), OpSetID(11), int_unallowed_types)); |
388 | registerAdapter(make_unique<TypeRestriction>("Min" , OpSetID(12), OpSetID(11), int_unallowed_types)); |
389 | registerAdapter(make_unique<TypeRestriction>("Max" , OpSetID(12), OpSetID(11), int_unallowed_types)); |
390 | registerAdapter(make_unique<TypeRestriction>("MaxPool" , OpSetID(12), OpSetID(11), maxpool_unallowed_types)); |
391 | registerAdapter(make_unique<TypeRestriction>("ReduceMax" , OpSetID(12), OpSetID(11), maxpool_unallowed_types)); |
392 | registerAdapter(make_unique<TypeRestriction>("ReduceMin" , OpSetID(12), OpSetID(11), maxpool_unallowed_types)); |
393 | |
394 | /******** 12 -> 13 ********/ |
395 | registerAdapter(make_unique<CompatibleAdapter>("Abs" , OpSetID(12), OpSetID(13))); |
396 | registerAdapter(make_unique<CompatibleAdapter>("Add" , OpSetID(12), OpSetID(13))); |
397 | registerAdapter(make_unique<CompatibleAdapter>("ArgMin" , OpSetID(12), OpSetID(13))); |
398 | registerAdapter(make_unique<CompatibleAdapter>("ArgMax" , OpSetID(12), OpSetID(13))); |
399 | registerAdapter(make_unique<CompatibleAdapter>("Cast" , OpSetID(12), OpSetID(13))); |
400 | registerAdapter(make_unique<CompatibleAdapter>("Ceil" , OpSetID(12), OpSetID(13))); |
401 | registerAdapter(make_unique<CompatibleAdapter>("Clip" , OpSetID(12), OpSetID(13))); |
402 | registerAdapter(make_unique<CompatibleAdapter>("Concat" , OpSetID(12), OpSetID(13))); |
403 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(12), OpSetID(13))); |
404 | registerAdapter(make_unique<CompatibleAdapter>("DepthToSpace" , OpSetID(12), OpSetID(13))); |
405 | registerAdapter(make_unique<CompatibleAdapter>("DequantizeLinear" , OpSetID(12), OpSetID(13))); |
406 | registerAdapter(make_unique<CompatibleAdapter>("Div" , OpSetID(12), OpSetID(13))); |
407 | registerAdapter(make_unique<CompatibleAdapter>("Dropout" , OpSetID(12), OpSetID(13))); |
408 | registerAdapter(make_unique<CompatibleAdapter>("Equal" , OpSetID(12), OpSetID(13))); |
409 | registerAdapter(make_unique<CompatibleAdapter>("Erf" , OpSetID(12), OpSetID(13))); |
410 | registerAdapter(make_unique<CompatibleAdapter>("Exp" , OpSetID(12), OpSetID(13))); |
411 | registerAdapter(make_unique<CompatibleAdapter>("Expand" , OpSetID(12), OpSetID(13))); |
412 | registerAdapter(make_unique<CompatibleAdapter>("Flatten" , OpSetID(12), OpSetID(13))); |
413 | registerAdapter(make_unique<CompatibleAdapter>("Floor" , OpSetID(12), OpSetID(13))); |
414 | registerAdapter(make_unique<CompatibleAdapter>("Gather" , OpSetID(12), OpSetID(13))); |
415 | registerAdapter(make_unique<CompatibleAdapter>("GatherElements" , OpSetID(12), OpSetID(13))); |
416 | registerAdapter(make_unique<CompatibleAdapter>("GatherND" , OpSetID(12), OpSetID(13))); |
417 | registerAdapter(make_unique<CompatibleAdapter>("Gemm" , OpSetID(12), OpSetID(13))); |
418 | registerAdapter(make_unique<CompatibleAdapter>("Greater" , OpSetID(12), OpSetID(13))); |
419 | registerAdapter(make_unique<CompatibleAdapter>("Hardmax" , OpSetID(12), OpSetID(13))); |
420 | registerAdapter(make_unique<CompatibleAdapter>("Identity" , OpSetID(12), OpSetID(13))); |
421 | registerAdapter(make_unique<CompatibleAdapter>("If" , OpSetID(12), OpSetID(13))); |
422 | registerAdapter(make_unique<CompatibleAdapter>("IsNaN" , OpSetID(12), OpSetID(13))); |
423 | registerAdapter(make_unique<CompatibleAdapter>("Less" , OpSetID(12), OpSetID(13))); |
424 | registerAdapter(make_unique<CompatibleAdapter>("Log" , OpSetID(12), OpSetID(13))); |
425 | registerAdapter(make_unique<CompatibleAdapter>("Loop" , OpSetID(12), OpSetID(13))); |
426 | registerAdapter(make_unique<CompatibleAdapter>("LRN" , OpSetID(12), OpSetID(13))); |
427 | registerAdapter(make_unique<CompatibleAdapter>("NegativeLogLikelihoodLoss" , OpSetID(12), OpSetID(13))); |
428 | registerAdapter(make_unique<CompatibleAdapter>("MatMul" , OpSetID(12), OpSetID(13))); |
429 | registerAdapter(make_unique<CompatibleAdapter>("Max" , OpSetID(12), OpSetID(13))); |
430 | registerAdapter(make_unique<CompatibleAdapter>("Mean" , OpSetID(12), OpSetID(13))); |
431 | registerAdapter(make_unique<CompatibleAdapter>("MeanVarianceNormalization" , OpSetID(12), OpSetID(13))); |
432 | registerAdapter(make_unique<CompatibleAdapter>("Min" , OpSetID(12), OpSetID(13))); |
433 | registerAdapter(make_unique<CompatibleAdapter>("Mod" , OpSetID(12), OpSetID(13))); |
434 | registerAdapter(make_unique<CompatibleAdapter>("Mul" , OpSetID(12), OpSetID(13))); |
435 | registerAdapter(make_unique<CompatibleAdapter>("Neg" , OpSetID(12), OpSetID(13))); |
436 | registerAdapter(make_unique<CompatibleAdapter>("NonZero" , OpSetID(12), OpSetID(13))); |
437 | registerAdapter(make_unique<CompatibleAdapter>("Pow" , OpSetID(12), OpSetID(13))); |
438 | registerAdapter(make_unique<CompatibleAdapter>("Pad" , OpSetID(12), OpSetID(13))); |
439 | registerAdapter(make_unique<CompatibleAdapter>("QuantizeLinear" , OpSetID(12), OpSetID(13))); |
440 | registerAdapter(make_unique<CompatibleAdapter>("Reciprocal" , OpSetID(12), OpSetID(13))); |
441 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL1" , OpSetID(12), OpSetID(13))); |
442 | registerAdapter(make_unique<CompatibleAdapter>("ReduceL2" , OpSetID(12), OpSetID(13))); |
443 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSum" , OpSetID(12), OpSetID(13))); |
444 | registerAdapter(make_unique<CompatibleAdapter>("ReduceLogSumExp" , OpSetID(12), OpSetID(13))); |
445 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMean" , OpSetID(12), OpSetID(13))); |
446 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMax" , OpSetID(12), OpSetID(13))); |
447 | registerAdapter(make_unique<CompatibleAdapter>("ReduceMin" , OpSetID(12), OpSetID(13))); |
448 | registerAdapter(make_unique<CompatibleAdapter>("ReduceProd" , OpSetID(12), OpSetID(13))); |
449 | registerAdapter(make_unique<CompatibleAdapter>("ReduceSumSquare" , OpSetID(12), OpSetID(13))); |
450 | registerAdapter(make_unique<CompatibleAdapter>("Relu" , OpSetID(12), OpSetID(13))); |
451 | registerAdapter(make_unique<CompatibleAdapter>("Reshape" , OpSetID(12), OpSetID(13))); |
452 | registerAdapter(make_unique<CompatibleAdapter>("Resize" , OpSetID(12), OpSetID(13))); |
453 | registerAdapter(make_unique<CompatibleAdapter>("ScatterElements" , OpSetID(12), OpSetID(13))); |
454 | registerAdapter(make_unique<CompatibleAdapter>("ScatterND" , OpSetID(12), OpSetID(13))); |
455 | registerAdapter(make_unique<CompatibleAdapter>("Shape" , OpSetID(12), OpSetID(13))); |
456 | registerAdapter(make_unique<CompatibleAdapter>("Sigmoid" , OpSetID(12), OpSetID(13))); |
457 | registerAdapter(make_unique<CompatibleAdapter>("Sign" , OpSetID(12), OpSetID(13))); |
458 | registerAdapter(make_unique<CompatibleAdapter>("Size" , OpSetID(12), OpSetID(13))); |
459 | registerAdapter(make_unique<CompatibleAdapter>("Slice" , OpSetID(12), OpSetID(13))); |
460 | registerAdapter(make_unique<CompatibleAdapter>("SoftmaxCrossEntropyLoss" , OpSetID(12), OpSetID(13))); |
461 | registerAdapter(make_unique<CompatibleAdapter>("SpaceToDepth" , OpSetID(12), OpSetID(13))); |
462 | registerAdapter(make_unique<CompatibleAdapter>("Sqrt" , OpSetID(12), OpSetID(13))); |
463 | registerAdapter(make_unique<CompatibleAdapter>("Sub" , OpSetID(12), OpSetID(13))); |
464 | registerAdapter(make_unique<CompatibleAdapter>("Sum" , OpSetID(12), OpSetID(13))); |
465 | registerAdapter(make_unique<CompatibleAdapter>("Tanh" , OpSetID(12), OpSetID(13))); |
466 | registerAdapter(make_unique<CompatibleAdapter>("Tile" , OpSetID(12), OpSetID(13))); |
467 | registerAdapter(make_unique<CompatibleAdapter>("Transpose" , OpSetID(12), OpSetID(13))); |
468 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceSum" , OpSetID(12), OpSetID(13))); |
469 | registerAdapter(make_unique<AxesAttributeToInput>("Squeeze" , OpSetID(12), OpSetID(13))); |
470 | registerAdapter(make_unique<AxesAttributeToInput>("Unsqueeze" , OpSetID(12), OpSetID(13))); |
471 | registerAdapter(make_unique<Split_12_13>()); |
472 | registerAdapter(make_unique<Softmax_12_13>("Softmax" )); |
473 | registerAdapter(make_unique<Softmax_12_13>("LogSoftmax" )); |
474 | |
475 | /******** 13 -> 12 ********/ |
476 | registerAdapter(make_unique<CompatibleAdapter>("Constant" , OpSetID(13), OpSetID(12))); |
477 | registerAdapter(make_unique<AxesInputToAttribute>("ReduceSum" , OpSetID(13), OpSetID(12))); |
478 | registerAdapter(make_unique<AxesInputToAttribute>("Squeeze" , OpSetID(13), OpSetID(12))); |
479 | registerAdapter(make_unique<AxesInputToAttribute>("Unsqueeze" , OpSetID(13), OpSetID(12))); |
480 | registerAdapter(make_unique<Split_13_12>()); |
481 | |
482 | /******** 13 -> 14 ********/ |
483 | registerAdapter(make_unique<CompatibleAdapter>("Add" , OpSetID(13), OpSetID(14))); |
484 | registerAdapter(make_unique<CompatibleAdapter>("CumSum" , OpSetID(13), OpSetID(14))); |
485 | registerAdapter(make_unique<CompatibleAdapter>("Div" , OpSetID(13), OpSetID(14))); |
486 | registerAdapter(make_unique<CompatibleAdapter>("Identity" , OpSetID(13), OpSetID(14))); |
487 | registerAdapter(make_unique<CompatibleAdapter>("Mul" , OpSetID(13), OpSetID(14))); |
488 | registerAdapter(make_unique<CompatibleAdapter>("Relu" , OpSetID(13), OpSetID(14))); |
489 | registerAdapter(make_unique<CompatibleAdapter>("Reshape" , OpSetID(13), OpSetID(14))); |
490 | registerAdapter(make_unique<CompatibleAdapter>("Sub" , OpSetID(13), OpSetID(14))); |
491 | registerAdapter("GRU" , 13, 14, SetAttribute(klayout, 0)); |
492 | registerAdapter("LSTM" , 13, 14, SetAttribute(klayout, 0)); |
493 | registerAdapter("RNN" , 13, 14, SetAttribute(klayout, 0)); |
494 | registerAdapter(make_unique<BatchNormalization_13_14>()); |
495 | |
496 | /******** 14 -> 13 ********/ |
497 | registerAdapter("GRU" , 14, 13, RemoveAttribute(klayout, 0)); |
498 | registerAdapter("LSTM" , 14, 13, RemoveAttribute(klayout, 0)); |
499 | registerAdapter("RNN" , 14, 13, RemoveAttribute(klayout, 0)); |
500 | |
501 | /******** 14 -> 15 ********/ |
502 | registerAdapter(make_unique<CompatibleAdapter>("BatchNormalization" , OpSetID(14), OpSetID(15))); |
503 | registerAdapter(make_unique<CompatibleAdapter>("Pow" , OpSetID(14), OpSetID(15))); |
504 | registerAdapter(make_unique<CompatibleAdapter>("Shape" , OpSetID(14), OpSetID(15))); |
505 | |
506 | /******** 15 -> 16 ********/ |
507 | registerAdapter("RoiAlign" , 15, 16, SetAttribute(kcoordinate_transformation_mode, "output_half_pixel" )); |
508 | registerAdapter(make_unique<CompatibleAdapter>("ScatterElements" , OpSetID(15), OpSetID(16))); |
509 | registerAdapter(make_unique<CompatibleAdapter>("ScatterND" , OpSetID(15), OpSetID(16))); |
510 | registerAdapter(make_unique<CompatibleAdapter>("Identity" , OpSetID(15), OpSetID(16))); |
511 | registerAdapter(make_unique<CompatibleAdapter>("Loop" , OpSetID(15), OpSetID(16))); |
512 | registerAdapter(make_unique<CompatibleAdapter>("If" , OpSetID(15), OpSetID(16))); |
513 | registerAdapter(make_unique<CompatibleAdapter>("Where" , OpSetID(15), OpSetID(16))); |
514 | registerAdapter(make_unique<CompatibleAdapter>("Scan" , OpSetID(15), OpSetID(16))); |
515 | registerAdapter(make_unique<CompatibleAdapter>("LessOrEqual" , OpSetID(15), OpSetID(16))); |
516 | registerAdapter(make_unique<CompatibleAdapter>("GreaterOrEqual" , OpSetID(15), OpSetID(16))); |
517 | registerAdapter(make_unique<CompatibleAdapter>("LeakyRelu" , OpSetID(15), OpSetID(16))); |
518 | registerAdapter(make_unique<CompatibleAdapter>("PRelu" , OpSetID(15), OpSetID(16))); |
519 | |
520 | /******** 17 -> 18 ********/ |
521 | registerAdapter(make_unique<CompatibleAdapter>("Pad" , OpSetID(17), OpSetID(18))); |
522 | registerAdapter(make_unique<CompatibleAdapter>("Resize" , OpSetID(17), OpSetID(18))); |
523 | registerAdapter(make_unique<CompatibleAdapter>("OptionalGetElement" , OpSetID(17), OpSetID(18))); |
524 | registerAdapter(make_unique<CompatibleAdapter>("OptionalHasElement" , OpSetID(17), OpSetID(18))); |
525 | registerAdapter(make_unique<Split_17_18>()); |
526 | registerAdapter(make_unique<CompatibleAdapter>("ScatterND" , OpSetID(17), OpSetID(18))); |
527 | registerAdapter(make_unique<CompatibleAdapter>("ScatterElements" , OpSetID(17), OpSetID(18))); |
528 | registerAdapter("LpPool" , 17, 18, SetAttribute(kceil_mode, 0)); |
529 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceL1" , OpSetID(17), OpSetID(18))); |
530 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceL2" , OpSetID(17), OpSetID(18))); |
531 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceLogSum" , OpSetID(17), OpSetID(18))); |
532 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceLogSumExp" , OpSetID(17), OpSetID(18))); |
533 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceMax" , OpSetID(17), OpSetID(18))); |
534 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceMean" , OpSetID(17), OpSetID(18))); |
535 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceMin" , OpSetID(17), OpSetID(18))); |
536 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceProd" , OpSetID(17), OpSetID(18))); |
537 | registerAdapter(make_unique<AxesAttributeToInput>("ReduceSumSquare" , OpSetID(17), OpSetID(18))); |
538 | } |
539 | |
540 | ModelProto convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) |
541 | const override; |
542 | }; |
543 | |
544 | ModelProto ConvertVersion(const ModelProto& mp_in, int target_version); |
545 | } // namespace version_conversion |
546 | } // namespace ONNX_NAMESPACE |
547 | |