1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. |
6 | // Adventurous users should note that the APIs will probably change. |
7 | |
8 | #include "onnx/common/ir_pb_converter.h" |
9 | #include <sstream> |
10 | |
11 | namespace ONNX_NAMESPACE { |
12 | |
13 | // Part 1: convert ONNX Protobuf to IR |
14 | std::unique_ptr<Graph> graphProtoToGraph(const GraphProto& gp, bool nested, const int ir_version = IR_VERSION); |
15 | |
16 | Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) { |
17 | Tensor ret; |
18 | |
19 | ret.sizes().reserve(tp.dims_size()); |
20 | for (int i = 0; i < tp.dims_size(); i++) { |
21 | ret.sizes().push_back(tp.dims(i)); |
22 | } |
23 | |
24 | ret.elem_type() = tp.data_type(); |
25 | switch (tp.data_type()) { |
26 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: |
27 | case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: { |
28 | ret.floats().reserve(tp.float_data_size()); |
29 | for (int i = 0; i < tp.float_data_size(); i++) { |
30 | ret.floats().push_back(tp.float_data(i)); |
31 | } |
32 | break; |
33 | } |
34 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: |
35 | case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: |
36 | case ONNX_NAMESPACE::TensorProto_DataType_BOOL: |
37 | case ONNX_NAMESPACE::TensorProto_DataType_INT8: |
38 | case ONNX_NAMESPACE::TensorProto_DataType_INT16: |
39 | case ONNX_NAMESPACE::TensorProto_DataType_INT32: |
40 | case ONNX_NAMESPACE::TensorProto_DataType_UINT8: |
41 | case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { |
42 | ret.int32s().reserve(tp.int32_data_size()); |
43 | for (int i = 0; i < tp.int32_data_size(); i++) { |
44 | ret.int32s().push_back(tp.int32_data(i)); |
45 | } |
46 | break; |
47 | } |
48 | case ONNX_NAMESPACE::TensorProto_DataType_INT64: { |
49 | ret.int64s().reserve(tp.int64_data_size()); |
50 | for (int i = 0; i < tp.int64_data_size(); i++) { |
51 | ret.int64s().push_back(tp.int64_data(i)); |
52 | } |
53 | break; |
54 | } |
55 | case ONNX_NAMESPACE::TensorProto_DataType_UINT32: |
56 | case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { |
57 | ret.uint64s().reserve(tp.uint64_data_size()); |
58 | for (int i = 0; i < tp.uint64_data_size(); i++) { |
59 | ret.uint64s().push_back(tp.uint64_data(i)); |
60 | } |
61 | break; |
62 | } |
63 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: |
64 | case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: { |
65 | ret.doubles().reserve(tp.double_data_size()); |
66 | for (int i = 0; i < tp.double_data_size(); i++) { |
67 | ret.doubles().push_back(tp.double_data(i)); |
68 | } |
69 | break; |
70 | } |
71 | case ONNX_NAMESPACE::TensorProto_DataType_STRING: { |
72 | ret.strings().reserve(tp.string_data_size()); |
73 | for (int i = 0; i < tp.string_data_size(); i++) { |
74 | ret.strings().push_back(tp.string_data(i)); |
75 | } |
76 | break; |
77 | } |
78 | case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: |
79 | fail_convert("Unknown tensor data type" ); |
80 | } |
81 | |
82 | // The only way to know if we should be using raw_data or |
83 | // <type>_data is to look at which of them is size zero. |
84 | if (tp.has_raw_data()) { |
85 | ret.set_raw_data(tp.raw_data()); |
86 | } |
87 | |
88 | if (tp.has_name()) { |
89 | ret.setName(tp.name()); |
90 | } |
91 | if (tp.has_segment()) { |
92 | ret.set_segment_begin_and_end(tp.segment().begin(), tp.segment().end()); |
93 | } |
94 | return ret; |
95 | } |
96 | |
97 | void convertAttribute(const ONNX_NAMESPACE::AttributeProto& ap, Node* n, const int ir_version = IR_VERSION) { |
98 | Symbol sym = Symbol(ap.name()); |
99 | switch (ap.type()) { |
100 | case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT: |
101 | n->f_(sym, ap.f()); |
102 | break; |
103 | case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS: { |
104 | std::vector<double> floats; |
105 | floats.reserve(ap.floats_size()); |
106 | for (int i = 0; i < ap.floats_size(); i++) { |
107 | floats.push_back(ap.floats(i)); |
108 | } |
109 | n->fs_(sym, std::move(floats)); |
110 | break; |
111 | } |
112 | case ONNX_NAMESPACE::AttributeProto_AttributeType_INT: |
113 | n->i_(sym, ap.i()); |
114 | break; |
115 | case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS: { |
116 | std::vector<int64_t> ints; |
117 | ints.reserve(ap.ints_size()); |
118 | for (int i = 0; i < ap.ints_size(); i++) { |
119 | ints.push_back(ap.ints(i)); |
120 | } |
121 | n->is_(sym, std::move(ints)); |
122 | break; |
123 | } |
124 | case ONNX_NAMESPACE::AttributeProto_AttributeType_STRING: |
125 | n->s_(sym, ap.s()); |
126 | break; |
127 | case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS: { |
128 | std::vector<std::string> strings; |
129 | strings.reserve(ap.strings_size()); |
130 | for (int i = 0; i < ap.strings_size(); i++) { |
131 | strings.push_back(ap.strings(i)); |
132 | } |
133 | n->ss_(sym, std::move(strings)); |
134 | break; |
135 | } |
136 | case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR: |
137 | n->t_(sym, tensorProtoToTensor(ap.t())); |
138 | break; |
139 | case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS: { |
140 | std::vector<Tensor> tensors; |
141 | tensors.reserve(ap.tensors_size()); |
142 | for (int i = 0; i < ap.tensors_size(); i++) { |
143 | tensors.push_back(tensorProtoToTensor(ap.tensors(i))); |
144 | } |
145 | n->ts_(sym, std::move(tensors)); |
146 | break; |
147 | } |
148 | case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO: |
149 | n->tp_(sym, ap.tp()); |
150 | break; |
151 | case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS: { |
152 | std::vector<TypeProto> types; |
153 | types.reserve(ap.type_protos_size()); |
154 | for (int i = 0; i < ap.type_protos_size(); i++) { |
155 | types.push_back(ap.type_protos(i)); |
156 | } |
157 | n->tps_(sym, std::move(types)); |
158 | break; |
159 | } |
160 | case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH: |
161 | n->g_(sym, graphProtoToGraph(ap.g(), true, ir_version)); |
162 | break; |
163 | case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS: { |
164 | std::vector<std::shared_ptr<Graph>> graphs; |
165 | graphs.reserve(ap.graphs_size()); |
166 | for (int i = 0; i < ap.graphs_size(); i++) { |
167 | graphs.push_back(graphProtoToGraph(ap.graphs(i), true, ir_version)); |
168 | } |
169 | n->gs_(sym, std::move(graphs)); |
170 | break; |
171 | } |
172 | case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR: |
173 | case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS: |
174 | fail_convert("Sparse tensors not supported." ); |
175 | break; |
176 | case ONNX_NAMESPACE::AttributeProto_AttributeType_UNDEFINED: |
177 | fail_convert("Unknown tensor data type" ); |
178 | break; |
179 | } |
180 | } |
181 | |
182 | void convertAttributes(ONNX_NAMESPACE::NodeProto& np, Node* n, const int ir_version = IR_VERSION) { |
183 | for (int i = 0; i < np.attribute_size(); i++) { |
184 | convertAttribute(np.attribute(i), n, ir_version); |
185 | } |
186 | } |
187 | |
188 | std::vector<Dimension> tensorShapeProtoToDimensions(const ONNX_NAMESPACE::TensorShapeProto& tsp) { |
189 | std::vector<Dimension> dims; |
190 | dims.reserve(tsp.dim_size()); |
191 | for (int i = 0; i < tsp.dim_size(); i++) { |
192 | if (tsp.dim(i).has_dim_value()) { |
193 | dims.emplace_back(tsp.dim(i).dim_value()); |
194 | } else if (tsp.dim(i).has_dim_param()) { |
195 | dims.emplace_back(tsp.dim(i).dim_param()); |
196 | } else { |
197 | // a dimension that has neither dim_value nor dim_param set |
198 | // represents an unknown dimension unrelated to other unknown dimensions. |
199 | dims.emplace_back(); |
200 | } |
201 | } |
202 | return dims; |
203 | } |
204 | |
205 | void createDummyValue( |
206 | std::unique_ptr<Graph>& g, |
207 | const std::string& name, |
208 | std::unordered_map<std::string, Value*>& value_by_name_of) { |
209 | auto* undef = g->create(kCaptured, 1); |
210 | g->appendNode(undef); |
211 | undef->outputs()[0]->setUniqueName(name); |
212 | value_by_name_of[name] = undef->outputs()[0]; |
213 | } |
214 | |
// Converts a GraphProto into an IR Graph.
//
// `nested` is true when this graph is a subgraph (a graph-valued attribute);
// in that case, references to names not defined in this graph are modeled as
// kCaptured placeholder nodes (see createDummyValue) instead of errors.
// `ir_version` controls initializer handling: from IR version 4 onward an
// initializer need not also appear in the graph inputs.
//
// Throws std::out_of_range (via ONNX_THROW_EX) for an undefined input
// reference in a non-nested graph.
std::unique_ptr<Graph> graphProtoToGraph(const ONNX_NAMESPACE::GraphProto& gp, bool nested, const int ir_version) {
  std::unique_ptr<Graph> g(new Graph());

  if (gp.has_name()) {
    g->setName(gp.name());
  }
  if (gp.has_doc_string()) {
    g->setDocString(gp.doc_string());
  }

  // Values are created (as in `new Value(..)`) by the Node that
  // outputs them. Therefore we initialize the Nodes and Values in
  // several stages.
  //
  // 1) add all input (to the graph) Values, owned by the sentinel Param node
  // 2) add all Nodes and their output Values, but don't intialize inputs
  // 3) initialize inputs of all Nodes
  // 4) initialize inputs of the Return sentinel node
  // 5) fill in type info for graph outputs, and register them as outputs
  // 6) fill in type info for Values from the value_info list in the graph

  // In ONNX proto land, Values are just strings. We are going to make
  // objects out of them, and equal strings must be mapped to the same
  // Value object.
  std::unordered_map<std::string, Value*> value_by_name_of;

  // We initialize Node inputs in a separate pass from the Nodes
  // themselves. To do so, we need to have access to the names of the
  // inputs.
  std::unordered_map<Node*, std::vector<std::string>> inputs_by_node;

  {
    // ONNX represents optional arguments in two ways
    // - they are simply not provided
    // - OR the empty string is passed as the input name
    // This is to handle that second case, which needs a dummy node to
    // be representable in the graph IR.
    auto* n = g->create(kUndefined, 1);
    g->appendNode(n);
    n->outputs()[0]->setUniqueName("");
    value_by_name_of[""] = n->outputs()[0];
  }

  // Stage 1: graph inputs become Values, carrying elem type / shape when the
  // proto provides them.
  for (int i = 0; i < gp.input_size(); i++) {
    const auto& vip = gp.input(i);
    auto v = g->addInput();
    const auto& tensor_type = vip.type().tensor_type();
    if (tensor_type.has_elem_type()) {
      v->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      v->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
    v->setUniqueName(vip.name());
    value_by_name_of[vip.name()] = v;
  }

  // initializers should be added before all nodes,
  // otherwise getNextUnique() may conflicts with an existing initializer name.
  for (int i = 0; i < gp.initializer_size(); ++i) {
    auto init = tensorProtoToTensor(gp.initializer(i));
    // If ir_version >= 4, initializer does not have to be included in input
    // Create a Value from initializer by addInitializerNode if name does not exist in input
    // and save it into value_by_name_of for later use (node input)
    if (ir_version >= 4 && value_by_name_of.count(init.name()) == 0) {
      value_by_name_of[init.name()] = g->addInitializerAndCreateValue(init);
    } else {
      // If ir_version < 4 or the initializer exists in input
      // Simply add initializer without creating new value
      // which means it will prioritize input value over initializer value if both exist
      g->addInitializer(init);
    }
  }

  // Stage 2: create every Node and its output Values; input wiring is
  // deferred (a node may consume an output produced later in the list).
  for (int i = 0; i < gp.node_size(); i++) {
    auto np = gp.node(i);
    auto* n = g->create(Symbol(np.op_type()), /* num_outputs = */ np.output_size());
    g->appendNode(n);
    for (int j = 0; j < np.output_size(); j++) {
      auto out = n->outputs()[j];
      // we don't know the real type here, so that's done in a later pass
      out->setElemType(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
      out->setUniqueName(np.output(j));
      value_by_name_of[np.output(j)] = out;
    }
    convertAttributes(np, n, ir_version);
    // Remember the input names so stage 3 can wire them up.
    std::vector<std::string> inputs;
    inputs.reserve(np.input_size());
    for (int j = 0; j < np.input_size(); j++) {
      inputs.push_back(np.input(j));
    }
    inputs_by_node[n] = inputs;
    if (np.has_doc_string()) {
      n->setDocString(np.doc_string());
    }
    if (np.has_name()) {
      n->setName(np.name());
    }
    if (np.has_domain()) {
      n->setDomain(np.domain());
    }
  }

  // Stage 3: now that every Value exists, wire up node inputs by name.
  for (auto n : g->nodes()) {
    auto search = inputs_by_node.find(n);
    if (search == inputs_by_node.end()) {
      // Sentinel/placeholder nodes created above have no recorded inputs.
      continue;
    }
    for (const auto& input : search->second) {
      if (!value_by_name_of.count(input) && nested) {
        // Undefined reference to an input in a nested block. This may be a
        // captured value. Create a dummy node that we ignore later.
        createDummyValue(g, input, value_by_name_of);
      }

      if (!value_by_name_of.count(input)) {
        // Non-nested graph (or dummy creation not applicable): this is a
        // genuine error in the model.
        std::ostringstream msg;
        msg << "Input " << input << " is undefined!";
        ONNX_THROW_EX(std::out_of_range(msg.str()));
      }
      n->addInput(value_by_name_of.at(input));
    }
  }

  // Stages 4-5: register graph outputs, filling in their type info.
  for (int i = 0; i < gp.output_size(); i++) {
    if (!value_by_name_of.count(gp.output(i).name()) && nested) {
      // Same captured value logic as above. We can consider outputs of a
      // graph to be "inputs" of a dummy "output" node. The same lexical
      // scoping rules are valid here, thus we need to add a dummy node
      // in the case of an undefined reference
      createDummyValue(g, gp.output(i).name(), value_by_name_of);
    }
    const auto& output_tensor_type = gp.output(i).type().tensor_type();
    if (output_tensor_type.has_elem_type()) {
      value_by_name_of[gp.output(i).name()]->setElemType(output_tensor_type.elem_type());
    }
    if (output_tensor_type.has_shape()) {
      value_by_name_of[gp.output(i).name()]->setSizes(tensorShapeProtoToDimensions(output_tensor_type.shape()));
    }
    g->registerOutput(value_by_name_of[gp.output(i).name()]);
  }

  // Stage 6: backfill type info for intermediate values from value_info.
  for (int i = 0; i < gp.value_info_size(); i++) {
    const auto& tensor_type = gp.value_info(i).type().tensor_type();
    if (!value_by_name_of.count(gp.value_info(i).name())) {
      // Ideally the model should not have a value_info whose name does not exist in the graph (unused); simply skip it
      continue;
    }
    if (tensor_type.has_elem_type()) {
      value_by_name_of[gp.value_info(i).name()]->setElemType(tensor_type.elem_type());
    }
    if (tensor_type.has_shape()) {
      value_by_name_of[gp.value_info(i).name()]->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
    }
  }

  return g;
}
373 | |
374 | std::unique_ptr<Graph> ImportModelProto(const ModelProto& mp) { |
375 | if (!mp.has_ir_version()) { |
376 | return nullptr; |
377 | } else if (mp.ir_version() <= 1) { |
378 | // ir_version=1 is not supported and ir_version=0 is illegal |
379 | return nullptr; |
380 | } |
381 | |
382 | std::unique_ptr<Graph> g(graphProtoToGraph(mp.graph(), false, mp.ir_version())); |
383 | for (int i = 0; i < mp.opset_import_size(); i++) { |
384 | OpSetID new_opset_version(mp.opset_import(i).domain(), mp.opset_import(i).version()); |
385 | g->forSelfAndEachSubGraph( |
386 | [&new_opset_version](Graph* graph) { graph->opset_versions_mutable().emplace_back(new_opset_version); }); |
387 | } |
388 | return g; |
389 | } |
390 | |
391 | // Part 2: convert IR to ONNX Protobuf |
// Returns the string identifier used for this Value in proto land (ONNX
// protobufs refer to values purely by name).
std::string value_name(Value* n) {
  return n->uniqueName();
}
395 | |
396 | void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g); |
397 | |
398 | void encodeTensor(ONNX_NAMESPACE::TensorProto* p, const Tensor& tensor) { |
399 | if (tensor.hasName()) { |
400 | p->set_name(tensor.name()); |
401 | } |
402 | if (tensor.is_segment()) { |
403 | ONNX_NAMESPACE::TensorProto_Segment segment; |
404 | segment.set_begin(tensor.segment_begin()); |
405 | segment.set_end(tensor.segment_end()); |
406 | p->mutable_segment()->CopyFrom(segment); |
407 | } |
408 | for (auto d : tensor.sizes()) { |
409 | p->add_dims(d); |
410 | } |
411 | p->set_data_type(tensor.elem_type()); |
412 | switch (tensor.elem_type()) { |
413 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: |
414 | case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: { |
415 | for (float x : tensor.floats()) { |
416 | p->add_float_data(x); |
417 | } |
418 | break; |
419 | } |
420 | case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: |
421 | case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: |
422 | case ONNX_NAMESPACE::TensorProto_DataType_BOOL: |
423 | case ONNX_NAMESPACE::TensorProto_DataType_INT8: |
424 | case ONNX_NAMESPACE::TensorProto_DataType_INT16: |
425 | case ONNX_NAMESPACE::TensorProto_DataType_INT32: |
426 | case ONNX_NAMESPACE::TensorProto_DataType_UINT8: |
427 | case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { |
428 | for (int32_t x : tensor.int32s()) { |
429 | p->add_int32_data(x); |
430 | } |
431 | break; |
432 | } |
433 | case ONNX_NAMESPACE::TensorProto_DataType_INT64: { |
434 | for (int64_t x : tensor.int64s()) { |
435 | p->add_int64_data(x); |
436 | } |
437 | break; |
438 | } |
439 | case ONNX_NAMESPACE::TensorProto_DataType_UINT32: |
440 | case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { |
441 | for (uint64_t x : tensor.uint64s()) { |
442 | p->add_uint64_data(x); |
443 | } |
444 | break; |
445 | } |
446 | case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: |
447 | case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: { |
448 | for (double x : tensor.doubles()) { |
449 | p->add_double_data(x); |
450 | } |
451 | break; |
452 | } |
453 | case ONNX_NAMESPACE::TensorProto_DataType_STRING: { |
454 | for (const std::string& x : tensor.strings()) { |
455 | p->add_string_data(x); |
456 | } |
457 | break; |
458 | } |
459 | case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: |
460 | fail_convert("Unknown tensor data type" ); |
461 | } |
462 | if (tensor.is_raw_data()) { |
463 | p->set_raw_data(tensor.raw()); |
464 | } |
465 | } |
466 | |
467 | void addAttribute(ONNX_NAMESPACE::NodeProto* n_p, Node* n, Symbol name) { |
468 | auto attr = n_p->add_attribute(); |
469 | attr->set_name(name.toString()); |
470 | switch (n->kindOf(name)) { |
471 | case AttributeKind::f: { |
472 | attr->set_f(static_cast<float>(n->f(name))); |
473 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); |
474 | } break; |
475 | case AttributeKind::fs: { |
476 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS); |
477 | for (auto& v : n->fs(name)) |
478 | attr->add_floats(static_cast<float>(v)); |
479 | } break; |
480 | case AttributeKind::i: { |
481 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); |
482 | attr->set_i(n->i(name)); |
483 | } break; |
484 | case AttributeKind::is: { |
485 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS); |
486 | for (auto& v : n->is(name)) |
487 | attr->add_ints(v); |
488 | } break; |
489 | case AttributeKind::s: { |
490 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); |
491 | attr->set_s(n->s(name)); |
492 | } break; |
493 | case AttributeKind::ss: { |
494 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS); |
495 | for (auto& v : n->ss(name)) |
496 | attr->add_strings(v); |
497 | } break; |
498 | case AttributeKind::t: { |
499 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); |
500 | auto t = attr->mutable_t(); |
501 | encodeTensor(t, n->t(name)); |
502 | } break; |
503 | case AttributeKind::ts: { |
504 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS); |
505 | for (auto& v : n->ts(name)) { |
506 | auto t = attr->add_tensors(); |
507 | encodeTensor(t, v); |
508 | } |
509 | } break; |
510 | case AttributeKind::g: { |
511 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH); |
512 | auto g = attr->mutable_g(); |
513 | encodeGraph(g, n->g(name)); |
514 | } break; |
515 | case AttributeKind::gs: { |
516 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS); |
517 | for (auto& v : n->gs(name)) { |
518 | auto g = attr->add_graphs(); |
519 | encodeGraph(g, v); |
520 | } |
521 | } break; |
522 | case AttributeKind::tp: { |
523 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO); |
524 | auto tp = attr->mutable_tp(); |
525 | tp->CopyFrom(n->tp(name)); |
526 | } break; |
527 | case AttributeKind::tps: { |
528 | attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS); |
529 | for (auto& v : n->tps(name)) { |
530 | auto tp = attr->add_type_protos(); |
531 | tp->CopyFrom(v); |
532 | } |
533 | } break; |
534 | } |
535 | } |
536 | |
537 | void encodeTypeProtoTensorType(ONNX_NAMESPACE::TypeProto_Tensor* tensor_type, Value* n) { |
538 | if (n->elemType() != 0) { |
539 | tensor_type->set_elem_type(n->elemType()); |
540 | } |
541 | if (n->has_sizes()) { |
542 | ONNX_NAMESPACE::TensorShapeProto* shape = tensor_type->mutable_shape(); |
543 | for (const Dimension& d : n->sizes()) { |
544 | auto dim = shape->add_dim(); |
545 | if (!d.is_unknown) { |
546 | if (d.is_int) { |
547 | dim->set_dim_value(d.dim); |
548 | } else { |
549 | dim->set_dim_param(d.param); |
550 | } |
551 | } |
552 | } |
553 | } |
554 | } |
555 | |
556 | void encodeValueInfo(ONNX_NAMESPACE::ValueInfoProto* v, Value* n) { |
557 | v->set_name(value_name(n)); |
558 | if (n->elemType() != 0 || n->has_sizes()) { |
559 | ONNX_NAMESPACE::TypeProto* t = v->mutable_type(); |
560 | ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = t->mutable_tensor_type(); |
561 | encodeTypeProtoTensorType(tensor_type, n); |
562 | } |
563 | } |
564 | |
// Serializes IR Graph `g` into GraphProto `p_g`: name/doc, inputs, outputs,
// nodes (skipping the kUndefined/kCaptured sentinels this converter inserts
// on import), value_info for typed intermediates, and initializers.
void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g) {
  ONNX_ASSERT(p_g != nullptr);

  if (g->has_name()) {
    p_g->set_name(g->name());
  }

  if (g->has_doc_string()) {
    p_g->set_doc_string(g->docString());
  }

  for (auto input : g->inputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_input();
    encodeValueInfo(v, input);
  }
  for (auto output : g->outputs()) {
    ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_output();
    encodeValueInfo(v, output);
  }

  // Used below to avoid duplicating graph-output type info in value_info.
  std::unordered_set<Value*> graph_outputs(g->outputs().begin(), g->outputs().end());

  for (auto node : g->nodes()) {
    if (node->kind() == kUndefined || node->kind() == kCaptured) {
      // Undefined nodes are used to represent optional inputs that are not
      // provided.
      continue;
    }
    auto p_n = p_g->add_node();
    for (auto input : node->inputs()) {
      if (input->node()->kind() == kUndefined) {
        // An input fed by the kUndefined sentinel encodes an omitted
        // optional input, which ONNX spells as the empty string.
        p_n->add_input("");
      } else {
        p_n->add_input(value_name(input));
      }
    }
    for (auto output : node->outputs()) {
      p_n->add_output(value_name(output));
      // only save it if
      // - it has actual information worth saving
      // - it's not already saved in the graph outputs value info
      if (graph_outputs.find(output) != graph_outputs.end()) {
        continue;
      }
      if (output->elemType() == TensorProto_DataType_UNDEFINED && output->sizes().empty()) {
        continue;
      }
      ValueInfoProto* v = p_g->add_value_info();
      encodeValueInfo(v, output);
    }
    p_n->set_op_type(node->kind().toString());
    for (auto attr_name : node->attributeNames()) {
      addAttribute(p_n, node, attr_name);
    }
    if (node->has_doc_string()) {
      p_n->set_doc_string(node->docString());
    }
    if (node->has_name()) {
      p_n->set_name(node->name());
    }
    if (node->has_domain()) {
      p_n->set_domain(node->domain());
    }
  }

  // Initializers and their names are stored as parallel arrays in the IR;
  // pair them up by index.
  auto num_initializers = g->initializers().size();
  for (unsigned int i = 0; i < num_initializers; i++) {
    auto p = p_g->add_initializer();
    p->set_name(g->initializer_names()[i]);
    encodeTensor(p, g->initializers()[i]);
  }
}
637 | |
638 | void ExportModelProto(ModelProto* p_m, const std::shared_ptr<Graph>& g) { |
639 | GraphProto* p_g = p_m->mutable_graph(); |
640 | encodeGraph(p_g, g); |
641 | // Add new opset_versions |
642 | p_m->clear_opset_import(); |
643 | for (const OpSetID& opset : g->opset_versions_mutable()) { |
644 | OperatorSetIdProto* opset_version_output = p_m->add_opset_import(); |
645 | opset_version_output->set_domain(opset.domain()); |
646 | opset_version_output->set_version(opset.version()); |
647 | } |
648 | } |
649 | |
650 | ModelProto PrepareOutput(const ModelProto& mp_in) { |
651 | ModelProto mp_out{}; |
652 | |
653 | if (mp_in.has_ir_version()) { |
654 | mp_out.set_ir_version(mp_in.ir_version()); |
655 | } |
656 | if (mp_in.has_producer_name()) { |
657 | mp_out.set_producer_name(mp_in.producer_name()); |
658 | } |
659 | if (mp_in.has_producer_version()) { |
660 | mp_out.set_producer_version(mp_in.producer_version()); |
661 | } |
662 | if (mp_in.has_domain()) { |
663 | mp_out.set_domain(mp_in.domain()); |
664 | } |
665 | if (mp_in.has_model_version()) { |
666 | mp_out.set_model_version(mp_in.model_version()); |
667 | } |
668 | if (mp_in.has_doc_string()) { |
669 | mp_out.set_doc_string(mp_in.doc_string()); |
670 | } |
671 | for (int i = 0; i < mp_in.opset_import_size(); i++) { |
672 | auto& oi_in = mp_in.opset_import(i); |
673 | auto* oi_out = mp_out.add_opset_import(); |
674 | if (oi_in.has_domain()) { |
675 | oi_out->set_domain(oi_in.domain()); |
676 | } |
677 | if (oi_in.has_version()) { |
678 | oi_out->set_version(oi_in.version()); |
679 | } |
680 | } |
681 | for (int i = 0; i < mp_in.metadata_props_size(); i++) { |
682 | auto& pp_in = mp_in.metadata_props(i); |
683 | auto* pp_out = mp_out.add_metadata_props(); |
684 | if (pp_in.has_key()) { |
685 | pp_out->set_key(pp_in.key()); |
686 | } |
687 | if (pp_in.has_value()) { |
688 | pp_out->set_value(pp_in.value()); |
689 | } |
690 | } |
691 | |
692 | return mp_out; |
693 | } |
694 | |
695 | void assertNonNull(const std::shared_ptr<Graph>& g) { |
696 | ONNX_ASSERTM( |
697 | g.get() != nullptr, |
698 | "Warning: onnx version converter is unable to parse input model. " |
699 | "(The IR version of the ONNX model may be too old.)" ); |
700 | } |
701 | |
702 | } // namespace ONNX_NAMESPACE |
703 | |